You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

user_app.py 24KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import json
  18. import re
  19. from datetime import datetime
  20. from flask import request, session, redirect
  21. from werkzeug.security import generate_password_hash, check_password_hash
  22. from flask_login import login_required, current_user, login_user, logout_user
  23. from api.db.db_models import TenantLLM
  24. from api.db.services.llm_service import TenantLLMService, LLMService
  25. from api.utils.api_utils import (
  26. server_error_response,
  27. validate_request,
  28. get_data_error_result,
  29. )
  30. from api.utils import (
  31. get_uuid,
  32. get_format_time,
  33. decrypt,
  34. download_img,
  35. current_timestamp,
  36. datetime_format,
  37. )
  38. from api.db import UserTenantRole, FileType
  39. from api import settings
  40. from api.db.services.user_service import UserService, TenantService, UserTenantService
  41. from api.db.services.file_service import FileService
  42. from api.utils.api_utils import get_json_result, construct_response
  43. from api.apps.auth import get_auth_client
  44. @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
  45. def login():
  46. """
  47. User login endpoint.
  48. ---
  49. tags:
  50. - User
  51. parameters:
  52. - in: body
  53. name: body
  54. description: Login credentials.
  55. required: true
  56. schema:
  57. type: object
  58. properties:
  59. email:
  60. type: string
  61. description: User email.
  62. password:
  63. type: string
  64. description: User password.
  65. responses:
  66. 200:
  67. description: Login successful.
  68. schema:
  69. type: object
  70. 401:
  71. description: Authentication failed.
  72. schema:
  73. type: object
  74. """
  75. if not request.json:
  76. return get_json_result(
  77. data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
  78. )
  79. email = request.json.get("email", "")
  80. users = UserService.query(email=email)
  81. if not users:
  82. return get_json_result(
  83. data=False,
  84. code=settings.RetCode.AUTHENTICATION_ERROR,
  85. message=f"Email: {email} is not registered!",
  86. )
  87. password = request.json.get("password")
  88. try:
  89. password = decrypt(password)
  90. except BaseException:
  91. return get_json_result(
  92. data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
  93. )
  94. user = UserService.query_user(email, password)
  95. if user:
  96. response_data = user.to_json()
  97. user.access_token = get_uuid()
  98. login_user(user)
  99. user.update_time = (current_timestamp(),)
  100. user.update_date = (datetime_format(datetime.now()),)
  101. user.save()
  102. msg = "Welcome back!"
  103. return construct_response(data=response_data, auth=user.get_id(), message=msg)
  104. else:
  105. return get_json_result(
  106. data=False,
  107. code=settings.RetCode.AUTHENTICATION_ERROR,
  108. message="Email and password do not match!",
  109. )
  110. @manager.route("/login/channels", methods=["GET"]) # noqa: F821
  111. def get_login_channels():
  112. """
  113. Get all supported authentication channels.
  114. """
  115. try:
  116. channels = []
  117. for channel, config in settings.OAUTH_CONFIG.items():
  118. channels.append({
  119. "channel": channel,
  120. "display_name": config.get("display_name", channel.title()),
  121. "icon": config.get("icon", "sso"),
  122. })
  123. return get_json_result(data=channels)
  124. except Exception as e:
  125. logging.exception(e)
  126. return get_json_result(
  127. data=[],
  128. message=f"Load channels failure, error: {str(e)}",
  129. code=settings.RetCode.EXCEPTION_ERROR
  130. )
  131. @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
  132. def oauth_login(channel):
  133. channel_config = settings.OAUTH_CONFIG.get(channel)
  134. if not channel_config:
  135. raise ValueError(f"Invalid channel name: {channel}")
  136. auth_cli = get_auth_client(channel_config)
  137. state = get_uuid()
  138. session["oauth_state"] = state
  139. auth_url = auth_cli.get_authorization_url(state)
  140. return redirect(auth_url)
  141. @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
  142. def oauth_callback(channel):
  143. """
  144. Handle the OAuth/OIDC callback for various channels dynamically.
  145. """
  146. try:
  147. channel_config = settings.OAUTH_CONFIG.get(channel)
  148. if not channel_config:
  149. raise ValueError(f"Invalid channel name: {channel}")
  150. auth_cli = get_auth_client(channel_config)
  151. # Check the state
  152. state = request.args.get("state")
  153. if not state or state != session.get("oauth_state"):
  154. return redirect("/?error=invalid_state")
  155. session.pop("oauth_state", None)
  156. # Obtain the authorization code
  157. code = request.args.get("code")
  158. if not code:
  159. return redirect("/?error=missing_code")
  160. # Exchange authorization code for access token
  161. token_info = auth_cli.exchange_code_for_token(code)
  162. access_token = token_info.get("access_token")
  163. if not access_token:
  164. return redirect("/?error=token_failed")
  165. id_token = token_info.get("id_token")
  166. # Fetch user info
  167. user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
  168. if not user_info.email:
  169. return redirect("/?error=email_missing")
  170. # Login or register
  171. users = UserService.query(email=user_info.email)
  172. user_id = get_uuid()
  173. if not users:
  174. try:
  175. try:
  176. avatar = download_img(user_info.avatar_url)
  177. except Exception as e:
  178. logging.exception(e)
  179. avatar = ""
  180. users = user_register(
  181. user_id,
  182. {
  183. "access_token": get_uuid(),
  184. "email": user_info.email,
  185. "avatar": avatar,
  186. "nickname": user_info.nickname,
  187. "login_channel": channel,
  188. "last_login_time": get_format_time(),
  189. "is_superuser": False,
  190. },
  191. )
  192. if not users:
  193. raise Exception(f"Failed to register {user_info.email}")
  194. if len(users) > 1:
  195. raise Exception(f"Same email: {user_info.email} exists!")
  196. # Try to log in
  197. user = users[0]
  198. login_user(user)
  199. return redirect(f"/?auth={user.get_id()}")
  200. except Exception as e:
  201. rollback_user_registration(user_id)
  202. logging.exception(e)
  203. return redirect(f"/?error={str(e)}")
  204. # User exists, try to log in
  205. user = users[0]
  206. user.access_token = get_uuid()
  207. login_user(user)
  208. user.save()
  209. return redirect(f"/?auth={user.get_id()}")
  210. except Exception as e:
  211. logging.exception(e)
  212. return redirect(f"/?error={str(e)}")
  213. @manager.route("/github_callback", methods=["GET"]) # noqa: F821
  214. def github_callback():
  215. """
  216. **Deprecated**, Use `/oauth/callback/<channel>` instead.
  217. GitHub OAuth callback endpoint.
  218. ---
  219. tags:
  220. - OAuth
  221. parameters:
  222. - in: query
  223. name: code
  224. type: string
  225. required: true
  226. description: Authorization code from GitHub.
  227. responses:
  228. 200:
  229. description: Authentication successful.
  230. schema:
  231. type: object
  232. """
  233. import requests
  234. res = requests.post(
  235. settings.GITHUB_OAUTH.get("url"),
  236. data={
  237. "client_id": settings.GITHUB_OAUTH.get("client_id"),
  238. "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
  239. "code": request.args.get("code"),
  240. },
  241. headers={"Accept": "application/json"},
  242. )
  243. res = res.json()
  244. if "error" in res:
  245. return redirect("/?error=%s" % res["error_description"])
  246. if "user:email" not in res["scope"].split(","):
  247. return redirect("/?error=user:email not in scope")
  248. session["access_token"] = res["access_token"]
  249. session["access_token_from"] = "github"
  250. user_info = user_info_from_github(session["access_token"])
  251. email_address = user_info["email"]
  252. users = UserService.query(email=email_address)
  253. user_id = get_uuid()
  254. if not users:
  255. # User isn't try to register
  256. try:
  257. try:
  258. avatar = download_img(user_info["avatar_url"])
  259. except Exception as e:
  260. logging.exception(e)
  261. avatar = ""
  262. users = user_register(
  263. user_id,
  264. {
  265. "access_token": session["access_token"],
  266. "email": email_address,
  267. "avatar": avatar,
  268. "nickname": user_info["login"],
  269. "login_channel": "github",
  270. "last_login_time": get_format_time(),
  271. "is_superuser": False,
  272. },
  273. )
  274. if not users:
  275. raise Exception(f"Fail to register {email_address}.")
  276. if len(users) > 1:
  277. raise Exception(f"Same email: {email_address} exists!")
  278. # Try to log in
  279. user = users[0]
  280. login_user(user)
  281. return redirect("/?auth=%s" % user.get_id())
  282. except Exception as e:
  283. rollback_user_registration(user_id)
  284. logging.exception(e)
  285. return redirect("/?error=%s" % str(e))
  286. # User has already registered, try to log in
  287. user = users[0]
  288. user.access_token = get_uuid()
  289. login_user(user)
  290. user.save()
  291. return redirect("/?auth=%s" % user.get_id())
  292. @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
  293. def feishu_callback():
  294. """
  295. Feishu OAuth callback endpoint.
  296. ---
  297. tags:
  298. - OAuth
  299. parameters:
  300. - in: query
  301. name: code
  302. type: string
  303. required: true
  304. description: Authorization code from Feishu.
  305. responses:
  306. 200:
  307. description: Authentication successful.
  308. schema:
  309. type: object
  310. """
  311. import requests
  312. app_access_token_res = requests.post(
  313. settings.FEISHU_OAUTH.get("app_access_token_url"),
  314. data=json.dumps(
  315. {
  316. "app_id": settings.FEISHU_OAUTH.get("app_id"),
  317. "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
  318. }
  319. ),
  320. headers={"Content-Type": "application/json; charset=utf-8"},
  321. )
  322. app_access_token_res = app_access_token_res.json()
  323. if app_access_token_res["code"] != 0:
  324. return redirect("/?error=%s" % app_access_token_res)
  325. res = requests.post(
  326. settings.FEISHU_OAUTH.get("user_access_token_url"),
  327. data=json.dumps(
  328. {
  329. "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
  330. "code": request.args.get("code"),
  331. }
  332. ),
  333. headers={
  334. "Content-Type": "application/json; charset=utf-8",
  335. "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
  336. },
  337. )
  338. res = res.json()
  339. if res["code"] != 0:
  340. return redirect("/?error=%s" % res["message"])
  341. if "contact:user.email:readonly" not in res["data"]["scope"].split():
  342. return redirect("/?error=contact:user.email:readonly not in scope")
  343. session["access_token"] = res["data"]["access_token"]
  344. session["access_token_from"] = "feishu"
  345. user_info = user_info_from_feishu(session["access_token"])
  346. email_address = user_info["email"]
  347. users = UserService.query(email=email_address)
  348. user_id = get_uuid()
  349. if not users:
  350. # User isn't try to register
  351. try:
  352. try:
  353. avatar = download_img(user_info["avatar_url"])
  354. except Exception as e:
  355. logging.exception(e)
  356. avatar = ""
  357. users = user_register(
  358. user_id,
  359. {
  360. "access_token": session["access_token"],
  361. "email": email_address,
  362. "avatar": avatar,
  363. "nickname": user_info["en_name"],
  364. "login_channel": "feishu",
  365. "last_login_time": get_format_time(),
  366. "is_superuser": False,
  367. },
  368. )
  369. if not users:
  370. raise Exception(f"Fail to register {email_address}.")
  371. if len(users) > 1:
  372. raise Exception(f"Same email: {email_address} exists!")
  373. # Try to log in
  374. user = users[0]
  375. login_user(user)
  376. return redirect("/?auth=%s" % user.get_id())
  377. except Exception as e:
  378. rollback_user_registration(user_id)
  379. logging.exception(e)
  380. return redirect("/?error=%s" % str(e))
  381. # User has already registered, try to log in
  382. user = users[0]
  383. user.access_token = get_uuid()
  384. login_user(user)
  385. user.save()
  386. return redirect("/?auth=%s" % user.get_id())
  387. def user_info_from_feishu(access_token):
  388. import requests
  389. headers = {
  390. "Content-Type": "application/json; charset=utf-8",
  391. "Authorization": f"Bearer {access_token}",
  392. }
  393. res = requests.get(
  394. "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
  395. )
  396. user_info = res.json()["data"]
  397. user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
  398. return user_info
  399. def user_info_from_github(access_token):
  400. import requests
  401. headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
  402. res = requests.get(
  403. f"https://api.github.com/user?access_token={access_token}", headers=headers
  404. )
  405. user_info = res.json()
  406. email_info = requests.get(
  407. f"https://api.github.com/user/emails?access_token={access_token}",
  408. headers=headers,
  409. ).json()
  410. user_info["email"] = next(
  411. (email for email in email_info if email["primary"]), None
  412. )["email"]
  413. return user_info
  414. @manager.route("/logout", methods=["GET"]) # noqa: F821
  415. @login_required
  416. def log_out():
  417. """
  418. User logout endpoint.
  419. ---
  420. tags:
  421. - User
  422. security:
  423. - ApiKeyAuth: []
  424. responses:
  425. 200:
  426. description: Logout successful.
  427. schema:
  428. type: object
  429. """
  430. current_user.access_token = ""
  431. current_user.save()
  432. logout_user()
  433. return get_json_result(data=True)
  434. @manager.route("/setting", methods=["POST"]) # noqa: F821
  435. @login_required
  436. def setting_user():
  437. """
  438. Update user settings.
  439. ---
  440. tags:
  441. - User
  442. security:
  443. - ApiKeyAuth: []
  444. parameters:
  445. - in: body
  446. name: body
  447. description: User settings to update.
  448. required: true
  449. schema:
  450. type: object
  451. properties:
  452. nickname:
  453. type: string
  454. description: New nickname.
  455. email:
  456. type: string
  457. description: New email.
  458. responses:
  459. 200:
  460. description: Settings updated successfully.
  461. schema:
  462. type: object
  463. """
  464. update_dict = {}
  465. request_data = request.json
  466. if request_data.get("password"):
  467. new_password = request_data.get("new_password")
  468. if not check_password_hash(
  469. current_user.password, decrypt(request_data["password"])
  470. ):
  471. return get_json_result(
  472. data=False,
  473. code=settings.RetCode.AUTHENTICATION_ERROR,
  474. message="Password error!",
  475. )
  476. if new_password:
  477. update_dict["password"] = generate_password_hash(decrypt(new_password))
  478. for k in request_data.keys():
  479. if k in [
  480. "password",
  481. "new_password",
  482. "email",
  483. "status",
  484. "is_superuser",
  485. "login_channel",
  486. "is_anonymous",
  487. "is_active",
  488. "is_authenticated",
  489. "last_login_time",
  490. ]:
  491. continue
  492. update_dict[k] = request_data[k]
  493. try:
  494. UserService.update_by_id(current_user.id, update_dict)
  495. return get_json_result(data=True)
  496. except Exception as e:
  497. logging.exception(e)
  498. return get_json_result(
  499. data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
  500. )
  501. @manager.route("/info", methods=["GET"]) # noqa: F821
  502. @login_required
  503. def user_profile():
  504. """
  505. Get user profile information.
  506. ---
  507. tags:
  508. - User
  509. security:
  510. - ApiKeyAuth: []
  511. responses:
  512. 200:
  513. description: User profile retrieved successfully.
  514. schema:
  515. type: object
  516. properties:
  517. id:
  518. type: string
  519. description: User ID.
  520. nickname:
  521. type: string
  522. description: User nickname.
  523. email:
  524. type: string
  525. description: User email.
  526. """
  527. return get_json_result(data=current_user.to_dict())
  528. def rollback_user_registration(user_id):
  529. try:
  530. UserService.delete_by_id(user_id)
  531. except Exception:
  532. pass
  533. try:
  534. TenantService.delete_by_id(user_id)
  535. except Exception:
  536. pass
  537. try:
  538. u = UserTenantService.query(tenant_id=user_id)
  539. if u:
  540. UserTenantService.delete_by_id(u[0].id)
  541. except Exception:
  542. pass
  543. try:
  544. TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
  545. except Exception:
  546. pass
  547. def user_register(user_id, user):
  548. user["id"] = user_id
  549. tenant = {
  550. "id": user_id,
  551. "name": user["nickname"] + "‘s Kingdom",
  552. "llm_id": settings.CHAT_MDL,
  553. "embd_id": settings.EMBEDDING_MDL,
  554. "asr_id": settings.ASR_MDL,
  555. "parser_ids": settings.PARSERS,
  556. "img2txt_id": settings.IMAGE2TEXT_MDL,
  557. "rerank_id": settings.RERANK_MDL,
  558. }
  559. usr_tenant = {
  560. "tenant_id": user_id,
  561. "user_id": user_id,
  562. "invited_by": user_id,
  563. "role": UserTenantRole.OWNER,
  564. }
  565. file_id = get_uuid()
  566. file = {
  567. "id": file_id,
  568. "parent_id": file_id,
  569. "tenant_id": user_id,
  570. "created_by": user_id,
  571. "name": "/",
  572. "type": FileType.FOLDER.value,
  573. "size": 0,
  574. "location": "",
  575. }
  576. tenant_llm = []
  577. for llm in LLMService.query(fid=settings.LLM_FACTORY):
  578. tenant_llm.append(
  579. {
  580. "tenant_id": user_id,
  581. "llm_factory": settings.LLM_FACTORY,
  582. "llm_name": llm.llm_name,
  583. "model_type": llm.model_type,
  584. "api_key": settings.API_KEY,
  585. "api_base": settings.LLM_BASE_URL,
  586. "max_tokens": llm.max_tokens if llm.max_tokens else 8192
  587. }
  588. )
  589. if not UserService.save(**user):
  590. return
  591. TenantService.insert(**tenant)
  592. UserTenantService.insert(**usr_tenant)
  593. TenantLLMService.insert_many(tenant_llm)
  594. FileService.insert(file)
  595. return UserService.query(email=user["email"])
  596. @manager.route("/register", methods=["POST"]) # noqa: F821
  597. @validate_request("nickname", "email", "password")
  598. def user_add():
  599. """
  600. Register a new user.
  601. ---
  602. tags:
  603. - User
  604. parameters:
  605. - in: body
  606. name: body
  607. description: Registration details.
  608. required: true
  609. schema:
  610. type: object
  611. properties:
  612. nickname:
  613. type: string
  614. description: User nickname.
  615. email:
  616. type: string
  617. description: User email.
  618. password:
  619. type: string
  620. description: User password.
  621. responses:
  622. 200:
  623. description: Registration successful.
  624. schema:
  625. type: object
  626. """
  627. if not settings.REGISTER_ENABLED:
  628. return get_json_result(
  629. data=False,
  630. message="User registration is disabled!",
  631. code=settings.RetCode.OPERATING_ERROR,
  632. )
  633. req = request.json
  634. email_address = req["email"]
  635. # Validate the email address
  636. if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
  637. return get_json_result(
  638. data=False,
  639. message=f"Invalid email address: {email_address}!",
  640. code=settings.RetCode.OPERATING_ERROR,
  641. )
  642. # Check if the email address is already used
  643. if UserService.query(email=email_address):
  644. return get_json_result(
  645. data=False,
  646. message=f"Email: {email_address} has already registered!",
  647. code=settings.RetCode.OPERATING_ERROR,
  648. )
  649. # Construct user info data
  650. nickname = req["nickname"]
  651. user_dict = {
  652. "access_token": get_uuid(),
  653. "email": email_address,
  654. "nickname": nickname,
  655. "password": decrypt(req["password"]),
  656. "login_channel": "password",
  657. "last_login_time": get_format_time(),
  658. "is_superuser": False,
  659. }
  660. user_id = get_uuid()
  661. try:
  662. users = user_register(user_id, user_dict)
  663. if not users:
  664. raise Exception(f"Fail to register {email_address}.")
  665. if len(users) > 1:
  666. raise Exception(f"Same email: {email_address} exists!")
  667. user = users[0]
  668. login_user(user)
  669. return construct_response(
  670. data=user.to_json(),
  671. auth=user.get_id(),
  672. message=f"{nickname}, welcome aboard!",
  673. )
  674. except Exception as e:
  675. rollback_user_registration(user_id)
  676. logging.exception(e)
  677. return get_json_result(
  678. data=False,
  679. message=f"User registration failure, error: {str(e)}",
  680. code=settings.RetCode.EXCEPTION_ERROR,
  681. )
  682. @manager.route("/tenant_info", methods=["GET"]) # noqa: F821
  683. @login_required
  684. def tenant_info():
  685. """
  686. Get tenant information.
  687. ---
  688. tags:
  689. - Tenant
  690. security:
  691. - ApiKeyAuth: []
  692. responses:
  693. 200:
  694. description: Tenant information retrieved successfully.
  695. schema:
  696. type: object
  697. properties:
  698. tenant_id:
  699. type: string
  700. description: Tenant ID.
  701. name:
  702. type: string
  703. description: Tenant name.
  704. llm_id:
  705. type: string
  706. description: LLM ID.
  707. embd_id:
  708. type: string
  709. description: Embedding model ID.
  710. """
  711. try:
  712. tenants = TenantService.get_info_by(current_user.id)
  713. if not tenants:
  714. return get_data_error_result(message="Tenant not found!")
  715. return get_json_result(data=tenants[0])
  716. except Exception as e:
  717. return server_error_response(e)
  718. @manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
  719. @login_required
  720. @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
  721. def set_tenant_info():
  722. """
  723. Update tenant information.
  724. ---
  725. tags:
  726. - Tenant
  727. security:
  728. - ApiKeyAuth: []
  729. parameters:
  730. - in: body
  731. name: body
  732. description: Tenant information to update.
  733. required: true
  734. schema:
  735. type: object
  736. properties:
  737. tenant_id:
  738. type: string
  739. description: Tenant ID.
  740. llm_id:
  741. type: string
  742. description: LLM ID.
  743. embd_id:
  744. type: string
  745. description: Embedding model ID.
  746. asr_id:
  747. type: string
  748. description: ASR model ID.
  749. img2txt_id:
  750. type: string
  751. description: Image to Text model ID.
  752. responses:
  753. 200:
  754. description: Tenant information updated successfully.
  755. schema:
  756. type: object
  757. """
  758. req = request.json
  759. try:
  760. tid = req.pop("tenant_id")
  761. TenantService.update_by_id(tid, req)
  762. return get_json_result(data=True)
  763. except Exception as e:
  764. return server_error_response(e)