Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

ext_login.py 2.8KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import json
  2. import flask_login # type: ignore
  3. from flask import Response, request
  4. from flask_login import user_loaded_from_request, user_logged_in
  5. from werkzeug.exceptions import NotFound, Unauthorized
  6. import contexts
  7. from dify_app import DifyApp
  8. from extensions.ext_database import db
  9. from libs.passport import PassportService
  10. from models.account import Account
  11. from models.model import EndUser
  12. from services.account_service import AccountService
# Module-level Flask-Login manager. The decorated callbacks in this module
# register on it, and init_app() binds it to the application.
login_manager = flask_login.LoginManager()
# Flask-Login configuration
  15. @login_manager.request_loader
  16. def load_user_from_request(request_from_flask_login):
  17. """Load user based on the request."""
  18. auth_header = request.headers.get("Authorization", "")
  19. auth_token: str | None = None
  20. if auth_header:
  21. if " " not in auth_header:
  22. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  23. auth_scheme, auth_token = auth_header.split(maxsplit=1)
  24. auth_scheme = auth_scheme.lower()
  25. if auth_scheme != "bearer":
  26. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  27. else:
  28. auth_token = request.args.get("_token")
  29. if request.blueprint in {"console", "inner_api"}:
  30. if not auth_token:
  31. raise Unauthorized("Invalid Authorization token.")
  32. decoded = PassportService().verify(auth_token)
  33. user_id = decoded.get("user_id")
  34. if not user_id:
  35. raise Unauthorized("Invalid Authorization token.")
  36. logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
  37. return logged_in_account
  38. elif request.blueprint == "web":
  39. decoded = PassportService().verify(auth_token)
  40. end_user_id = decoded.get("end_user_id")
  41. if not end_user_id:
  42. raise Unauthorized("Invalid Authorization token.")
  43. end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
  44. if not end_user:
  45. raise NotFound("End user not found.")
  46. return end_user
  47. @user_logged_in.connect
  48. @user_loaded_from_request.connect
  49. def on_user_logged_in(_sender, user):
  50. """Called when a user logged in.
  51. Note: AccountService.load_logged_in_account will populate user.current_tenant_id
  52. through the load_user method, which calls account.set_tenant_id().
  53. """
  54. if user and isinstance(user, Account) and user.current_tenant_id:
  55. contexts.tenant_id.set(user.current_tenant_id)
  56. @login_manager.unauthorized_handler
  57. def unauthorized_handler():
  58. """Handle unauthorized requests."""
  59. return Response(
  60. json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
  61. status=401,
  62. content_type="application/json",
  63. )
def init_app(app: DifyApp) -> None:
    """Bind the module-level Flask-Login manager to ``app``.

    Args:
        app: the application instance that ``login_manager`` is registered with.
    """
    login_manager.init_app(app)