You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import logging
  18. import re
  19. import secrets
  20. from datetime import datetime
  21. from flask import redirect, request, session
  22. from flask_login import current_user, login_required, login_user, logout_user
  23. from werkzeug.security import check_password_hash, generate_password_hash
  24. from api import settings
  25. from api.apps.auth import get_auth_client
  26. from api.db import FileType, UserTenantRole
  27. from api.db.db_models import TenantLLM
  28. from api.db.services.file_service import FileService
  29. from api.db.services.llm_service import LLMService, TenantLLMService
  30. from api.db.services.user_service import TenantService, UserService, UserTenantService
  31. from api.utils import (
  32. current_timestamp,
  33. datetime_format,
  34. decrypt,
  35. download_img,
  36. get_format_time,
  37. get_uuid,
  38. )
  39. from api.utils.api_utils import (
  40. construct_response,
  41. get_data_error_result,
  42. get_json_result,
  43. server_error_response,
  44. validate_request,
  45. )
  46. @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
  47. def login():
  48. """
  49. User login endpoint.
  50. ---
  51. tags:
  52. - User
  53. parameters:
  54. - in: body
  55. name: body
  56. description: Login credentials.
  57. required: true
  58. schema:
  59. type: object
  60. properties:
  61. email:
  62. type: string
  63. description: User email.
  64. password:
  65. type: string
  66. description: User password.
  67. responses:
  68. 200:
  69. description: Login successful.
  70. schema:
  71. type: object
  72. 401:
  73. description: Authentication failed.
  74. schema:
  75. type: object
  76. """
  77. if not request.json:
  78. return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
  79. email = request.json.get("email", "")
  80. users = UserService.query(email=email)
  81. if not users:
  82. return get_json_result(
  83. data=False,
  84. code=settings.RetCode.AUTHENTICATION_ERROR,
  85. message=f"Email: {email} is not registered!",
  86. )
  87. password = request.json.get("password")
  88. try:
  89. password = decrypt(password)
  90. except BaseException:
  91. return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
  92. user = UserService.query_user(email, password)
  93. if user:
  94. response_data = user.to_json()
  95. user.access_token = get_uuid()
  96. login_user(user)
  97. user.update_time = (current_timestamp(),)
  98. user.update_date = (datetime_format(datetime.now()),)
  99. user.save()
  100. msg = "Welcome back!"
  101. return construct_response(data=response_data, auth=user.get_id(), message=msg)
  102. else:
  103. return get_json_result(
  104. data=False,
  105. code=settings.RetCode.AUTHENTICATION_ERROR,
  106. message="Email and password do not match!",
  107. )
  108. @manager.route("/login/channels", methods=["GET"]) # noqa: F821
  109. def get_login_channels():
  110. """
  111. Get all supported authentication channels.
  112. """
  113. try:
  114. channels = []
  115. for channel, config in settings.OAUTH_CONFIG.items():
  116. channels.append(
  117. {
  118. "channel": channel,
  119. "display_name": config.get("display_name", channel.title()),
  120. "icon": config.get("icon", "sso"),
  121. }
  122. )
  123. return get_json_result(data=channels)
  124. except Exception as e:
  125. logging.exception(e)
  126. return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR)
  127. @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
  128. def oauth_login(channel):
  129. channel_config = settings.OAUTH_CONFIG.get(channel)
  130. if not channel_config:
  131. raise ValueError(f"Invalid channel name: {channel}")
  132. auth_cli = get_auth_client(channel_config)
  133. state = get_uuid()
  134. session["oauth_state"] = state
  135. auth_url = auth_cli.get_authorization_url(state)
  136. return redirect(auth_url)
  137. @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
  138. def oauth_callback(channel):
  139. """
  140. Handle the OAuth/OIDC callback for various channels dynamically.
  141. """
  142. try:
  143. channel_config = settings.OAUTH_CONFIG.get(channel)
  144. if not channel_config:
  145. raise ValueError(f"Invalid channel name: {channel}")
  146. auth_cli = get_auth_client(channel_config)
  147. # Check the state
  148. state = request.args.get("state")
  149. if not state or state != session.get("oauth_state"):
  150. return redirect("/?error=invalid_state")
  151. session.pop("oauth_state", None)
  152. # Obtain the authorization code
  153. code = request.args.get("code")
  154. if not code:
  155. return redirect("/?error=missing_code")
  156. # Exchange authorization code for access token
  157. token_info = auth_cli.exchange_code_for_token(code)
  158. access_token = token_info.get("access_token")
  159. if not access_token:
  160. return redirect("/?error=token_failed")
  161. id_token = token_info.get("id_token")
  162. # Fetch user info
  163. user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
  164. if not user_info.email:
  165. return redirect("/?error=email_missing")
  166. # Login or register
  167. users = UserService.query(email=user_info.email)
  168. user_id = get_uuid()
  169. if not users:
  170. try:
  171. try:
  172. avatar = download_img(user_info.avatar_url)
  173. except Exception as e:
  174. logging.exception(e)
  175. avatar = ""
  176. users = user_register(
  177. user_id,
  178. {
  179. "access_token": get_uuid(),
  180. "email": user_info.email,
  181. "avatar": avatar,
  182. "nickname": user_info.nickname,
  183. "login_channel": channel,
  184. "last_login_time": get_format_time(),
  185. "is_superuser": False,
  186. },
  187. )
  188. if not users:
  189. raise Exception(f"Failed to register {user_info.email}")
  190. if len(users) > 1:
  191. raise Exception(f"Same email: {user_info.email} exists!")
  192. # Try to log in
  193. user = users[0]
  194. login_user(user)
  195. return redirect(f"/?auth={user.get_id()}")
  196. except Exception as e:
  197. rollback_user_registration(user_id)
  198. logging.exception(e)
  199. return redirect(f"/?error={str(e)}")
  200. # User exists, try to log in
  201. user = users[0]
  202. user.access_token = get_uuid()
  203. login_user(user)
  204. user.save()
  205. return redirect(f"/?auth={user.get_id()}")
  206. except Exception as e:
  207. logging.exception(e)
  208. return redirect(f"/?error={str(e)}")
  209. @manager.route("/github_callback", methods=["GET"]) # noqa: F821
  210. def github_callback():
  211. """
  212. **Deprecated**, Use `/oauth/callback/<channel>` instead.
  213. GitHub OAuth callback endpoint.
  214. ---
  215. tags:
  216. - OAuth
  217. parameters:
  218. - in: query
  219. name: code
  220. type: string
  221. required: true
  222. description: Authorization code from GitHub.
  223. responses:
  224. 200:
  225. description: Authentication successful.
  226. schema:
  227. type: object
  228. """
  229. import requests
  230. res = requests.post(
  231. settings.GITHUB_OAUTH.get("url"),
  232. data={
  233. "client_id": settings.GITHUB_OAUTH.get("client_id"),
  234. "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
  235. "code": request.args.get("code"),
  236. },
  237. headers={"Accept": "application/json"},
  238. )
  239. res = res.json()
  240. if "error" in res:
  241. return redirect("/?error=%s" % res["error_description"])
  242. if "user:email" not in res["scope"].split(","):
  243. return redirect("/?error=user:email not in scope")
  244. session["access_token"] = res["access_token"]
  245. session["access_token_from"] = "github"
  246. user_info = user_info_from_github(session["access_token"])
  247. email_address = user_info["email"]
  248. users = UserService.query(email=email_address)
  249. user_id = get_uuid()
  250. if not users:
  251. # User isn't try to register
  252. try:
  253. try:
  254. avatar = download_img(user_info["avatar_url"])
  255. except Exception as e:
  256. logging.exception(e)
  257. avatar = ""
  258. users = user_register(
  259. user_id,
  260. {
  261. "access_token": session["access_token"],
  262. "email": email_address,
  263. "avatar": avatar,
  264. "nickname": user_info["login"],
  265. "login_channel": "github",
  266. "last_login_time": get_format_time(),
  267. "is_superuser": False,
  268. },
  269. )
  270. if not users:
  271. raise Exception(f"Fail to register {email_address}.")
  272. if len(users) > 1:
  273. raise Exception(f"Same email: {email_address} exists!")
  274. # Try to log in
  275. user = users[0]
  276. login_user(user)
  277. return redirect("/?auth=%s" % user.get_id())
  278. except Exception as e:
  279. rollback_user_registration(user_id)
  280. logging.exception(e)
  281. return redirect("/?error=%s" % str(e))
  282. # User has already registered, try to log in
  283. user = users[0]
  284. user.access_token = get_uuid()
  285. login_user(user)
  286. user.save()
  287. return redirect("/?auth=%s" % user.get_id())
  288. @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
  289. def feishu_callback():
  290. """
  291. Feishu OAuth callback endpoint.
  292. ---
  293. tags:
  294. - OAuth
  295. parameters:
  296. - in: query
  297. name: code
  298. type: string
  299. required: true
  300. description: Authorization code from Feishu.
  301. responses:
  302. 200:
  303. description: Authentication successful.
  304. schema:
  305. type: object
  306. """
  307. import requests
  308. app_access_token_res = requests.post(
  309. settings.FEISHU_OAUTH.get("app_access_token_url"),
  310. data=json.dumps(
  311. {
  312. "app_id": settings.FEISHU_OAUTH.get("app_id"),
  313. "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
  314. }
  315. ),
  316. headers={"Content-Type": "application/json; charset=utf-8"},
  317. )
  318. app_access_token_res = app_access_token_res.json()
  319. if app_access_token_res["code"] != 0:
  320. return redirect("/?error=%s" % app_access_token_res)
  321. res = requests.post(
  322. settings.FEISHU_OAUTH.get("user_access_token_url"),
  323. data=json.dumps(
  324. {
  325. "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
  326. "code": request.args.get("code"),
  327. }
  328. ),
  329. headers={
  330. "Content-Type": "application/json; charset=utf-8",
  331. "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
  332. },
  333. )
  334. res = res.json()
  335. if res["code"] != 0:
  336. return redirect("/?error=%s" % res["message"])
  337. if "contact:user.email:readonly" not in res["data"]["scope"].split():
  338. return redirect("/?error=contact:user.email:readonly not in scope")
  339. session["access_token"] = res["data"]["access_token"]
  340. session["access_token_from"] = "feishu"
  341. user_info = user_info_from_feishu(session["access_token"])
  342. email_address = user_info["email"]
  343. users = UserService.query(email=email_address)
  344. user_id = get_uuid()
  345. if not users:
  346. # User isn't try to register
  347. try:
  348. try:
  349. avatar = download_img(user_info["avatar_url"])
  350. except Exception as e:
  351. logging.exception(e)
  352. avatar = ""
  353. users = user_register(
  354. user_id,
  355. {
  356. "access_token": session["access_token"],
  357. "email": email_address,
  358. "avatar": avatar,
  359. "nickname": user_info["en_name"],
  360. "login_channel": "feishu",
  361. "last_login_time": get_format_time(),
  362. "is_superuser": False,
  363. },
  364. )
  365. if not users:
  366. raise Exception(f"Fail to register {email_address}.")
  367. if len(users) > 1:
  368. raise Exception(f"Same email: {email_address} exists!")
  369. # Try to log in
  370. user = users[0]
  371. login_user(user)
  372. return redirect("/?auth=%s" % user.get_id())
  373. except Exception as e:
  374. rollback_user_registration(user_id)
  375. logging.exception(e)
  376. return redirect("/?error=%s" % str(e))
  377. # User has already registered, try to log in
  378. user = users[0]
  379. user.access_token = get_uuid()
  380. login_user(user)
  381. user.save()
  382. return redirect("/?auth=%s" % user.get_id())
  383. def user_info_from_feishu(access_token):
  384. import requests
  385. headers = {
  386. "Content-Type": "application/json; charset=utf-8",
  387. "Authorization": f"Bearer {access_token}",
  388. }
  389. res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
  390. user_info = res.json()["data"]
  391. user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
  392. return user_info
  393. def user_info_from_github(access_token):
  394. import requests
  395. headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
  396. res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
  397. user_info = res.json()
  398. email_info = requests.get(
  399. f"https://api.github.com/user/emails?access_token={access_token}",
  400. headers=headers,
  401. ).json()
  402. user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
  403. return user_info
  404. @manager.route("/logout", methods=["GET"]) # noqa: F821
  405. @login_required
  406. def log_out():
  407. """
  408. User logout endpoint.
  409. ---
  410. tags:
  411. - User
  412. security:
  413. - ApiKeyAuth: []
  414. responses:
  415. 200:
  416. description: Logout successful.
  417. schema:
  418. type: object
  419. """
  420. current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
  421. current_user.save()
  422. logout_user()
  423. return get_json_result(data=True)
  424. @manager.route("/setting", methods=["POST"]) # noqa: F821
  425. @login_required
  426. def setting_user():
  427. """
  428. Update user settings.
  429. ---
  430. tags:
  431. - User
  432. security:
  433. - ApiKeyAuth: []
  434. parameters:
  435. - in: body
  436. name: body
  437. description: User settings to update.
  438. required: true
  439. schema:
  440. type: object
  441. properties:
  442. nickname:
  443. type: string
  444. description: New nickname.
  445. email:
  446. type: string
  447. description: New email.
  448. responses:
  449. 200:
  450. description: Settings updated successfully.
  451. schema:
  452. type: object
  453. """
  454. update_dict = {}
  455. request_data = request.json
  456. if request_data.get("password"):
  457. new_password = request_data.get("new_password")
  458. if not check_password_hash(current_user.password, decrypt(request_data["password"])):
  459. return get_json_result(
  460. data=False,
  461. code=settings.RetCode.AUTHENTICATION_ERROR,
  462. message="Password error!",
  463. )
  464. if new_password:
  465. update_dict["password"] = generate_password_hash(decrypt(new_password))
  466. for k in request_data.keys():
  467. if k in [
  468. "password",
  469. "new_password",
  470. "email",
  471. "status",
  472. "is_superuser",
  473. "login_channel",
  474. "is_anonymous",
  475. "is_active",
  476. "is_authenticated",
  477. "last_login_time",
  478. ]:
  479. continue
  480. update_dict[k] = request_data[k]
  481. try:
  482. UserService.update_by_id(current_user.id, update_dict)
  483. return get_json_result(data=True)
  484. except Exception as e:
  485. logging.exception(e)
  486. return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR)
  487. @manager.route("/info", methods=["GET"]) # noqa: F821
  488. @login_required
  489. def user_profile():
  490. """
  491. Get user profile information.
  492. ---
  493. tags:
  494. - User
  495. security:
  496. - ApiKeyAuth: []
  497. responses:
  498. 200:
  499. description: User profile retrieved successfully.
  500. schema:
  501. type: object
  502. properties:
  503. id:
  504. type: string
  505. description: User ID.
  506. nickname:
  507. type: string
  508. description: User nickname.
  509. email:
  510. type: string
  511. description: User email.
  512. """
  513. return get_json_result(data=current_user.to_dict())
  514. def rollback_user_registration(user_id):
  515. try:
  516. UserService.delete_by_id(user_id)
  517. except Exception:
  518. pass
  519. try:
  520. TenantService.delete_by_id(user_id)
  521. except Exception:
  522. pass
  523. try:
  524. u = UserTenantService.query(tenant_id=user_id)
  525. if u:
  526. UserTenantService.delete_by_id(u[0].id)
  527. except Exception:
  528. pass
  529. try:
  530. TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
  531. except Exception:
  532. pass
  533. def user_register(user_id, user):
  534. user["id"] = user_id
  535. tenant = {
  536. "id": user_id,
  537. "name": user["nickname"] + "‘s Kingdom",
  538. "llm_id": settings.CHAT_MDL,
  539. "embd_id": settings.EMBEDDING_MDL,
  540. "asr_id": settings.ASR_MDL,
  541. "parser_ids": settings.PARSERS,
  542. "img2txt_id": settings.IMAGE2TEXT_MDL,
  543. "rerank_id": settings.RERANK_MDL,
  544. }
  545. usr_tenant = {
  546. "tenant_id": user_id,
  547. "user_id": user_id,
  548. "invited_by": user_id,
  549. "role": UserTenantRole.OWNER,
  550. }
  551. file_id = get_uuid()
  552. file = {
  553. "id": file_id,
  554. "parent_id": file_id,
  555. "tenant_id": user_id,
  556. "created_by": user_id,
  557. "name": "/",
  558. "type": FileType.FOLDER.value,
  559. "size": 0,
  560. "location": "",
  561. }
  562. tenant_llm = []
  563. for llm in LLMService.query(fid=settings.LLM_FACTORY):
  564. tenant_llm.append(
  565. {
  566. "tenant_id": user_id,
  567. "llm_factory": settings.LLM_FACTORY,
  568. "llm_name": llm.llm_name,
  569. "model_type": llm.model_type,
  570. "api_key": settings.API_KEY,
  571. "api_base": settings.LLM_BASE_URL,
  572. "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
  573. }
  574. )
  575. if settings.LIGHTEN != 1:
  576. for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
  577. mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
  578. tenant_llm.append(
  579. {
  580. "tenant_id": user_id,
  581. "llm_factory": fid,
  582. "llm_name": mdlnm,
  583. "model_type": "embedding",
  584. "api_key": "",
  585. "api_base": "",
  586. "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
  587. }
  588. )
  589. if not UserService.save(**user):
  590. return
  591. TenantService.insert(**tenant)
  592. UserTenantService.insert(**usr_tenant)
  593. TenantLLMService.insert_many(tenant_llm)
  594. FileService.insert(file)
  595. return UserService.query(email=user["email"])
  596. @manager.route("/register", methods=["POST"]) # noqa: F821
  597. @validate_request("nickname", "email", "password")
  598. def user_add():
  599. """
  600. Register a new user.
  601. ---
  602. tags:
  603. - User
  604. parameters:
  605. - in: body
  606. name: body
  607. description: Registration details.
  608. required: true
  609. schema:
  610. type: object
  611. properties:
  612. nickname:
  613. type: string
  614. description: User nickname.
  615. email:
  616. type: string
  617. description: User email.
  618. password:
  619. type: string
  620. description: User password.
  621. responses:
  622. 200:
  623. description: Registration successful.
  624. schema:
  625. type: object
  626. """
  627. if not settings.REGISTER_ENABLED:
  628. return get_json_result(
  629. data=False,
  630. message="User registration is disabled!",
  631. code=settings.RetCode.OPERATING_ERROR,
  632. )
  633. req = request.json
  634. email_address = req["email"]
  635. # Validate the email address
  636. if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
  637. return get_json_result(
  638. data=False,
  639. message=f"Invalid email address: {email_address}!",
  640. code=settings.RetCode.OPERATING_ERROR,
  641. )
  642. # Check if the email address is already used
  643. if UserService.query(email=email_address):
  644. return get_json_result(
  645. data=False,
  646. message=f"Email: {email_address} has already registered!",
  647. code=settings.RetCode.OPERATING_ERROR,
  648. )
  649. # Construct user info data
  650. nickname = req["nickname"]
  651. user_dict = {
  652. "access_token": get_uuid(),
  653. "email": email_address,
  654. "nickname": nickname,
  655. "password": decrypt(req["password"]),
  656. "login_channel": "password",
  657. "last_login_time": get_format_time(),
  658. "is_superuser": False,
  659. }
  660. user_id = get_uuid()
  661. try:
  662. users = user_register(user_id, user_dict)
  663. if not users:
  664. raise Exception(f"Fail to register {email_address}.")
  665. if len(users) > 1:
  666. raise Exception(f"Same email: {email_address} exists!")
  667. user = users[0]
  668. login_user(user)
  669. return construct_response(
  670. data=user.to_json(),
  671. auth=user.get_id(),
  672. message=f"{nickname}, welcome aboard!",
  673. )
  674. except Exception as e:
  675. rollback_user_registration(user_id)
  676. logging.exception(e)
  677. return get_json_result(
  678. data=False,
  679. message=f"User registration failure, error: {str(e)}",
  680. code=settings.RetCode.EXCEPTION_ERROR,
  681. )
  682. @manager.route("/tenant_info", methods=["GET"]) # noqa: F821
  683. @login_required
  684. def tenant_info():
  685. """
  686. Get tenant information.
  687. ---
  688. tags:
  689. - Tenant
  690. security:
  691. - ApiKeyAuth: []
  692. responses:
  693. 200:
  694. description: Tenant information retrieved successfully.
  695. schema:
  696. type: object
  697. properties:
  698. tenant_id:
  699. type: string
  700. description: Tenant ID.
  701. name:
  702. type: string
  703. description: Tenant name.
  704. llm_id:
  705. type: string
  706. description: LLM ID.
  707. embd_id:
  708. type: string
  709. description: Embedding model ID.
  710. """
  711. try:
  712. tenants = TenantService.get_info_by(current_user.id)
  713. if not tenants:
  714. return get_data_error_result(message="Tenant not found!")
  715. return get_json_result(data=tenants[0])
  716. except Exception as e:
  717. return server_error_response(e)
  718. @manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
  719. @login_required
  720. @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
  721. def set_tenant_info():
  722. """
  723. Update tenant information.
  724. ---
  725. tags:
  726. - Tenant
  727. security:
  728. - ApiKeyAuth: []
  729. parameters:
  730. - in: body
  731. name: body
  732. description: Tenant information to update.
  733. required: true
  734. schema:
  735. type: object
  736. properties:
  737. tenant_id:
  738. type: string
  739. description: Tenant ID.
  740. llm_id:
  741. type: string
  742. description: LLM ID.
  743. embd_id:
  744. type: string
  745. description: Embedding model ID.
  746. asr_id:
  747. type: string
  748. description: ASR model ID.
  749. img2txt_id:
  750. type: string
  751. description: Image to Text model ID.
  752. responses:
  753. 200:
  754. description: Tenant information updated successfully.
  755. schema:
  756. type: object
  757. """
  758. req = request.json
  759. try:
  760. tid = req.pop("tenant_id")
  761. TenantService.update_by_id(tid, req)
  762. return get_json_result(data=True)
  763. except Exception as e:
  764. return server_error_response(e)