You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

user_app.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import json
  18. import re
  19. from datetime import datetime
  20. from flask import request, session, redirect
  21. from werkzeug.security import generate_password_hash, check_password_hash
  22. from flask_login import login_required, current_user, login_user, logout_user
  23. from api.db.db_models import TenantLLM
  24. from api.db.services.llm_service import TenantLLMService, LLMService
  25. from api.utils.api_utils import (
  26. server_error_response,
  27. validate_request,
  28. get_data_error_result,
  29. )
  30. from api.utils import (
  31. get_uuid,
  32. get_format_time,
  33. decrypt,
  34. download_img,
  35. current_timestamp,
  36. datetime_format,
  37. )
  38. from api.db import UserTenantRole, FileType
  39. from api.settings import (
  40. RetCode,
  41. GITHUB_OAUTH,
  42. FEISHU_OAUTH,
  43. CHAT_MDL,
  44. EMBEDDING_MDL,
  45. ASR_MDL,
  46. IMAGE2TEXT_MDL,
  47. PARSERS,
  48. API_KEY,
  49. LLM_FACTORY,
  50. LLM_BASE_URL,
  51. RERANK_MDL,
  52. )
  53. from api.db.services.user_service import UserService, TenantService, UserTenantService
  54. from api.db.services.file_service import FileService
  55. from api.utils.api_utils import get_json_result, construct_response
  56. @manager.route("/login", methods=["POST", "GET"])
  57. def login():
  58. """
  59. User login endpoint.
  60. ---
  61. tags:
  62. - User
  63. parameters:
  64. - in: body
  65. name: body
  66. description: Login credentials.
  67. required: true
  68. schema:
  69. type: object
  70. properties:
  71. email:
  72. type: string
  73. description: User email.
  74. password:
  75. type: string
  76. description: User password.
  77. responses:
  78. 200:
  79. description: Login successful.
  80. schema:
  81. type: object
  82. 401:
  83. description: Authentication failed.
  84. schema:
  85. type: object
  86. """
  87. if not request.json:
  88. return get_json_result(
  89. data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
  90. )
  91. email = request.json.get("email", "")
  92. users = UserService.query(email=email)
  93. if not users:
  94. return get_json_result(
  95. data=False,
  96. code=RetCode.AUTHENTICATION_ERROR,
  97. message=f"Email: {email} is not registered!",
  98. )
  99. password = request.json.get("password")
  100. try:
  101. password = decrypt(password)
  102. except BaseException:
  103. return get_json_result(
  104. data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
  105. )
  106. user = UserService.query_user(email, password)
  107. if user:
  108. response_data = user.to_json()
  109. user.access_token = get_uuid()
  110. login_user(user)
  111. user.update_time = (current_timestamp(),)
  112. user.update_date = (datetime_format(datetime.now()),)
  113. user.save()
  114. msg = "Welcome back!"
  115. return construct_response(data=response_data, auth=user.get_id(), message=msg)
  116. else:
  117. return get_json_result(
  118. data=False,
  119. code=RetCode.AUTHENTICATION_ERROR,
  120. message="Email and password do not match!",
  121. )
  122. @manager.route("/github_callback", methods=["GET"])
  123. def github_callback():
  124. """
  125. GitHub OAuth callback endpoint.
  126. ---
  127. tags:
  128. - OAuth
  129. parameters:
  130. - in: query
  131. name: code
  132. type: string
  133. required: true
  134. description: Authorization code from GitHub.
  135. responses:
  136. 200:
  137. description: Authentication successful.
  138. schema:
  139. type: object
  140. """
  141. import requests
  142. res = requests.post(
  143. GITHUB_OAUTH.get("url"),
  144. data={
  145. "client_id": GITHUB_OAUTH.get("client_id"),
  146. "client_secret": GITHUB_OAUTH.get("secret_key"),
  147. "code": request.args.get("code"),
  148. },
  149. headers={"Accept": "application/json"},
  150. )
  151. res = res.json()
  152. if "error" in res:
  153. return redirect("/?error=%s" % res["error_description"])
  154. if "user:email" not in res["scope"].split(","):
  155. return redirect("/?error=user:email not in scope")
  156. session["access_token"] = res["access_token"]
  157. session["access_token_from"] = "github"
  158. user_info = user_info_from_github(session["access_token"])
  159. email_address = user_info["email"]
  160. users = UserService.query(email=email_address)
  161. user_id = get_uuid()
  162. if not users:
  163. # User isn't try to register
  164. try:
  165. try:
  166. avatar = download_img(user_info["avatar_url"])
  167. except Exception as e:
  168. logging.exception(e)
  169. avatar = ""
  170. users = user_register(
  171. user_id,
  172. {
  173. "access_token": session["access_token"],
  174. "email": email_address,
  175. "avatar": avatar,
  176. "nickname": user_info["login"],
  177. "login_channel": "github",
  178. "last_login_time": get_format_time(),
  179. "is_superuser": False,
  180. },
  181. )
  182. if not users:
  183. raise Exception(f"Fail to register {email_address}.")
  184. if len(users) > 1:
  185. raise Exception(f"Same email: {email_address} exists!")
  186. # Try to log in
  187. user = users[0]
  188. login_user(user)
  189. return redirect("/?auth=%s" % user.get_id())
  190. except Exception as e:
  191. rollback_user_registration(user_id)
  192. logging.exception(e)
  193. return redirect("/?error=%s" % str(e))
  194. # User has already registered, try to log in
  195. user = users[0]
  196. user.access_token = get_uuid()
  197. login_user(user)
  198. user.save()
  199. return redirect("/?auth=%s" % user.get_id())
  200. @manager.route("/feishu_callback", methods=["GET"])
  201. def feishu_callback():
  202. """
  203. Feishu OAuth callback endpoint.
  204. ---
  205. tags:
  206. - OAuth
  207. parameters:
  208. - in: query
  209. name: code
  210. type: string
  211. required: true
  212. description: Authorization code from Feishu.
  213. responses:
  214. 200:
  215. description: Authentication successful.
  216. schema:
  217. type: object
  218. """
  219. import requests
  220. app_access_token_res = requests.post(
  221. FEISHU_OAUTH.get("app_access_token_url"),
  222. data=json.dumps(
  223. {
  224. "app_id": FEISHU_OAUTH.get("app_id"),
  225. "app_secret": FEISHU_OAUTH.get("app_secret"),
  226. }
  227. ),
  228. headers={"Content-Type": "application/json; charset=utf-8"},
  229. )
  230. app_access_token_res = app_access_token_res.json()
  231. if app_access_token_res["code"] != 0:
  232. return redirect("/?error=%s" % app_access_token_res)
  233. res = requests.post(
  234. FEISHU_OAUTH.get("user_access_token_url"),
  235. data=json.dumps(
  236. {
  237. "grant_type": FEISHU_OAUTH.get("grant_type"),
  238. "code": request.args.get("code"),
  239. }
  240. ),
  241. headers={
  242. "Content-Type": "application/json; charset=utf-8",
  243. "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
  244. },
  245. )
  246. res = res.json()
  247. if res["code"] != 0:
  248. return redirect("/?error=%s" % res["message"])
  249. if "contact:user.email:readonly" not in res["data"]["scope"].split(" "):
  250. return redirect("/?error=contact:user.email:readonly not in scope")
  251. session["access_token"] = res["data"]["access_token"]
  252. session["access_token_from"] = "feishu"
  253. user_info = user_info_from_feishu(session["access_token"])
  254. email_address = user_info["email"]
  255. users = UserService.query(email=email_address)
  256. user_id = get_uuid()
  257. if not users:
  258. # User isn't try to register
  259. try:
  260. try:
  261. avatar = download_img(user_info["avatar_url"])
  262. except Exception as e:
  263. logging.exception(e)
  264. avatar = ""
  265. users = user_register(
  266. user_id,
  267. {
  268. "access_token": session["access_token"],
  269. "email": email_address,
  270. "avatar": avatar,
  271. "nickname": user_info["en_name"],
  272. "login_channel": "feishu",
  273. "last_login_time": get_format_time(),
  274. "is_superuser": False,
  275. },
  276. )
  277. if not users:
  278. raise Exception(f"Fail to register {email_address}.")
  279. if len(users) > 1:
  280. raise Exception(f"Same email: {email_address} exists!")
  281. # Try to log in
  282. user = users[0]
  283. login_user(user)
  284. return redirect("/?auth=%s" % user.get_id())
  285. except Exception as e:
  286. rollback_user_registration(user_id)
  287. logging.exception(e)
  288. return redirect("/?error=%s" % str(e))
  289. # User has already registered, try to log in
  290. user = users[0]
  291. user.access_token = get_uuid()
  292. login_user(user)
  293. user.save()
  294. return redirect("/?auth=%s" % user.get_id())
  295. def user_info_from_feishu(access_token):
  296. import requests
  297. headers = {
  298. "Content-Type": "application/json; charset=utf-8",
  299. "Authorization": f"Bearer {access_token}",
  300. }
  301. res = requests.get(
  302. "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
  303. )
  304. user_info = res.json()["data"]
  305. user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
  306. return user_info
  307. def user_info_from_github(access_token):
  308. import requests
  309. headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
  310. res = requests.get(
  311. f"https://api.github.com/user?access_token={access_token}", headers=headers
  312. )
  313. user_info = res.json()
  314. email_info = requests.get(
  315. f"https://api.github.com/user/emails?access_token={access_token}",
  316. headers=headers,
  317. ).json()
  318. user_info["email"] = next(
  319. (email for email in email_info if email["primary"] == True), None
  320. )["email"]
  321. return user_info
  322. @manager.route("/logout", methods=["GET"])
  323. @login_required
  324. def log_out():
  325. """
  326. User logout endpoint.
  327. ---
  328. tags:
  329. - User
  330. security:
  331. - ApiKeyAuth: []
  332. responses:
  333. 200:
  334. description: Logout successful.
  335. schema:
  336. type: object
  337. """
  338. current_user.access_token = ""
  339. current_user.save()
  340. logout_user()
  341. return get_json_result(data=True)
  342. @manager.route("/setting", methods=["POST"])
  343. @login_required
  344. def setting_user():
  345. """
  346. Update user settings.
  347. ---
  348. tags:
  349. - User
  350. security:
  351. - ApiKeyAuth: []
  352. parameters:
  353. - in: body
  354. name: body
  355. description: User settings to update.
  356. required: true
  357. schema:
  358. type: object
  359. properties:
  360. nickname:
  361. type: string
  362. description: New nickname.
  363. email:
  364. type: string
  365. description: New email.
  366. responses:
  367. 200:
  368. description: Settings updated successfully.
  369. schema:
  370. type: object
  371. """
  372. update_dict = {}
  373. request_data = request.json
  374. if request_data.get("password"):
  375. new_password = request_data.get("new_password")
  376. if not check_password_hash(
  377. current_user.password, decrypt(request_data["password"])
  378. ):
  379. return get_json_result(
  380. data=False,
  381. code=RetCode.AUTHENTICATION_ERROR,
  382. message="Password error!",
  383. )
  384. if new_password:
  385. update_dict["password"] = generate_password_hash(decrypt(new_password))
  386. for k in request_data.keys():
  387. if k in [
  388. "password",
  389. "new_password",
  390. "email",
  391. "status",
  392. "is_superuser",
  393. "login_channel",
  394. "is_anonymous",
  395. "is_active",
  396. "is_authenticated",
  397. "last_login_time",
  398. ]:
  399. continue
  400. update_dict[k] = request_data[k]
  401. try:
  402. UserService.update_by_id(current_user.id, update_dict)
  403. return get_json_result(data=True)
  404. except Exception as e:
  405. logging.exception(e)
  406. return get_json_result(
  407. data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
  408. )
  409. @manager.route("/info", methods=["GET"])
  410. @login_required
  411. def user_profile():
  412. """
  413. Get user profile information.
  414. ---
  415. tags:
  416. - User
  417. security:
  418. - ApiKeyAuth: []
  419. responses:
  420. 200:
  421. description: User profile retrieved successfully.
  422. schema:
  423. type: object
  424. properties:
  425. id:
  426. type: string
  427. description: User ID.
  428. nickname:
  429. type: string
  430. description: User nickname.
  431. email:
  432. type: string
  433. description: User email.
  434. """
  435. return get_json_result(data=current_user.to_dict())
  436. def rollback_user_registration(user_id):
  437. try:
  438. UserService.delete_by_id(user_id)
  439. except Exception:
  440. pass
  441. try:
  442. TenantService.delete_by_id(user_id)
  443. except Exception:
  444. pass
  445. try:
  446. u = UserTenantService.query(tenant_id=user_id)
  447. if u:
  448. UserTenantService.delete_by_id(u[0].id)
  449. except Exception:
  450. pass
  451. try:
  452. TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
  453. except Exception:
  454. pass
  455. def user_register(user_id, user):
  456. user["id"] = user_id
  457. tenant = {
  458. "id": user_id,
  459. "name": user["nickname"] + "‘s Kingdom",
  460. "llm_id": CHAT_MDL,
  461. "embd_id": EMBEDDING_MDL,
  462. "asr_id": ASR_MDL,
  463. "parser_ids": PARSERS,
  464. "img2txt_id": IMAGE2TEXT_MDL,
  465. "rerank_id": RERANK_MDL,
  466. }
  467. usr_tenant = {
  468. "tenant_id": user_id,
  469. "user_id": user_id,
  470. "invited_by": user_id,
  471. "role": UserTenantRole.OWNER,
  472. }
  473. file_id = get_uuid()
  474. file = {
  475. "id": file_id,
  476. "parent_id": file_id,
  477. "tenant_id": user_id,
  478. "created_by": user_id,
  479. "name": "/",
  480. "type": FileType.FOLDER.value,
  481. "size": 0,
  482. "location": "",
  483. }
  484. tenant_llm = []
  485. for llm in LLMService.query(fid=LLM_FACTORY):
  486. tenant_llm.append(
  487. {
  488. "tenant_id": user_id,
  489. "llm_factory": LLM_FACTORY,
  490. "llm_name": llm.llm_name,
  491. "model_type": llm.model_type,
  492. "api_key": API_KEY,
  493. "api_base": LLM_BASE_URL,
  494. }
  495. )
  496. if not UserService.save(**user):
  497. return
  498. TenantService.insert(**tenant)
  499. UserTenantService.insert(**usr_tenant)
  500. TenantLLMService.insert_many(tenant_llm)
  501. FileService.insert(file)
  502. return UserService.query(email=user["email"])
  503. @manager.route("/register", methods=["POST"])
  504. @validate_request("nickname", "email", "password")
  505. def user_add():
  506. """
  507. Register a new user.
  508. ---
  509. tags:
  510. - User
  511. parameters:
  512. - in: body
  513. name: body
  514. description: Registration details.
  515. required: true
  516. schema:
  517. type: object
  518. properties:
  519. nickname:
  520. type: string
  521. description: User nickname.
  522. email:
  523. type: string
  524. description: User email.
  525. password:
  526. type: string
  527. description: User password.
  528. responses:
  529. 200:
  530. description: Registration successful.
  531. schema:
  532. type: object
  533. """
  534. req = request.json
  535. email_address = req["email"]
  536. # Validate the email address
  537. if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,5}$", email_address):
  538. return get_json_result(
  539. data=False,
  540. message=f"Invalid email address: {email_address}!",
  541. code=RetCode.OPERATING_ERROR,
  542. )
  543. # Check if the email address is already used
  544. if UserService.query(email=email_address):
  545. return get_json_result(
  546. data=False,
  547. message=f"Email: {email_address} has already registered!",
  548. code=RetCode.OPERATING_ERROR,
  549. )
  550. # Construct user info data
  551. nickname = req["nickname"]
  552. user_dict = {
  553. "access_token": get_uuid(),
  554. "email": email_address,
  555. "nickname": nickname,
  556. "password": decrypt(req["password"]),
  557. "login_channel": "password",
  558. "last_login_time": get_format_time(),
  559. "is_superuser": False,
  560. }
  561. user_id = get_uuid()
  562. try:
  563. users = user_register(user_id, user_dict)
  564. if not users:
  565. raise Exception(f"Fail to register {email_address}.")
  566. if len(users) > 1:
  567. raise Exception(f"Same email: {email_address} exists!")
  568. user = users[0]
  569. login_user(user)
  570. return construct_response(
  571. data=user.to_json(),
  572. auth=user.get_id(),
  573. message=f"{nickname}, welcome aboard!",
  574. )
  575. except Exception as e:
  576. rollback_user_registration(user_id)
  577. logging.exception(e)
  578. return get_json_result(
  579. data=False,
  580. message=f"User registration failure, error: {str(e)}",
  581. code=RetCode.EXCEPTION_ERROR,
  582. )
  583. @manager.route("/tenant_info", methods=["GET"])
  584. @login_required
  585. def tenant_info():
  586. """
  587. Get tenant information.
  588. ---
  589. tags:
  590. - Tenant
  591. security:
  592. - ApiKeyAuth: []
  593. responses:
  594. 200:
  595. description: Tenant information retrieved successfully.
  596. schema:
  597. type: object
  598. properties:
  599. tenant_id:
  600. type: string
  601. description: Tenant ID.
  602. name:
  603. type: string
  604. description: Tenant name.
  605. llm_id:
  606. type: string
  607. description: LLM ID.
  608. embd_id:
  609. type: string
  610. description: Embedding model ID.
  611. """
  612. try:
  613. tenants = TenantService.get_info_by(current_user.id)
  614. if not tenants:
  615. return get_data_error_result(message="Tenant not found!")
  616. return get_json_result(data=tenants[0])
  617. except Exception as e:
  618. return server_error_response(e)
  619. @manager.route("/set_tenant_info", methods=["POST"])
  620. @login_required
  621. @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
  622. def set_tenant_info():
  623. """
  624. Update tenant information.
  625. ---
  626. tags:
  627. - Tenant
  628. security:
  629. - ApiKeyAuth: []
  630. parameters:
  631. - in: body
  632. name: body
  633. description: Tenant information to update.
  634. required: true
  635. schema:
  636. type: object
  637. properties:
  638. tenant_id:
  639. type: string
  640. description: Tenant ID.
  641. llm_id:
  642. type: string
  643. description: LLM ID.
  644. embd_id:
  645. type: string
  646. description: Embedding model ID.
  647. asr_id:
  648. type: string
  649. description: ASR model ID.
  650. img2txt_id:
  651. type: string
  652. description: Image to Text model ID.
  653. responses:
  654. 200:
  655. description: Tenant information updated successfully.
  656. schema:
  657. type: object
  658. """
  659. req = request.json
  660. try:
  661. tid = req["tenant_id"]
  662. del req["tenant_id"]
  663. TenantService.update_by_id(tid, req)
  664. return get_json_result(data=True)
  665. except Exception as e:
  666. return server_error_response(e)