| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- import json
-
- import flask_login # type: ignore
- from flask import Response, request
- from flask_login import user_loaded_from_request, user_logged_in
- from werkzeug.exceptions import NotFound, Unauthorized
-
- import contexts
- from dify_app import DifyApp
- from extensions.ext_database import db
- from libs.passport import PassportService
- from models.account import Account
- from models.model import EndUser
- from services.account_service import AccountService
-
# Module-wide Flask-Login manager; the decorated loaders/handlers below
# register themselves on it, and init_app() binds it to the Flask app.
login_manager: flask_login.LoginManager = flask_login.LoginManager()
-
-
# Flask-Login configuration
def _extract_auth_token() -> str | None:
    """Return the auth token of the current request, or None if none was sent.

    When a non-empty ``Authorization`` header is present it must look like
    ``Bearer <token>`` (scheme compared case-insensitively) and the token part
    is returned. Otherwise the ``_token`` query-string argument is returned
    (which may be None).

    Raises:
        Unauthorized: if the ``Authorization`` header is present but malformed.
    """
    auth_header = request.headers.get("Authorization", "")
    if not auth_header:
        return request.args.get("_token")

    # split() without a separator drops leading whitespace and collapses runs,
    # so headers like "Bearer " or " token" produce a single part. Checking the
    # length turns those into a 401 instead of a ValueError (HTTP 500) from
    # unpacking. The explicit " " check keeps the original rule that the
    # scheme and token must be separated by a space character.
    parts = auth_header.split(maxsplit=1)
    if " " not in auth_header or len(parts) != 2 or parts[0].lower() != "bearer":
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    return parts[1]


@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
    """Load the user for the current request (Flask-Login request loader).

    The token is taken from the ``Authorization: Bearer <token>`` header, or
    from the ``_token`` query argument when no header is sent. It is then
    verified with ``PassportService`` and resolved according to the request's
    blueprint.

    Args:
        request_from_flask_login: the request object passed by Flask-Login.
            Not used directly; the ``flask.request`` proxy is read instead.

    Returns:
        For the ``console`` / ``inner_api`` blueprints, the result of
        ``AccountService.load_logged_in_account`` for the token's ``user_id``.
        For the ``web`` blueprint, the ``EndUser`` row matching the token's
        ``end_user_id``. For any other blueprint, None.

    Raises:
        Unauthorized: if the Authorization header is malformed, or if the
            token is missing or lacks the expected id claim for a
            console/inner_api/web request.
        NotFound: if a web token references an end user that does not exist.
    """
    # Parsed up-front (as before) so a malformed header is rejected on every
    # blueprint, not only the ones that require authentication.
    auth_token = _extract_auth_token()

    if request.blueprint in {"console", "inner_api"}:
        if not auth_token:
            raise Unauthorized("Invalid Authorization token.")
        decoded = PassportService().verify(auth_token)
        user_id = decoded.get("user_id")
        if not user_id:
            raise Unauthorized("Invalid Authorization token.")
        return AccountService.load_logged_in_account(account_id=user_id)

    if request.blueprint == "web":
        # Same guard as the console branch: never hand a missing token to verify().
        if not auth_token:
            raise Unauthorized("Invalid Authorization token.")
        decoded = PassportService().verify(auth_token)
        end_user_id = decoded.get("end_user_id")
        if not end_user_id:
            raise Unauthorized("Invalid Authorization token.")
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
        if not end_user:
            raise NotFound("End user not found.")
        return end_user

    # Other blueprints are not authenticated by this loader.
    return None
-
-
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
    """Propagate an account's current tenant into the request context.

    Connected to both Flask-Login's ``user_logged_in`` and
    ``user_loaded_from_request`` signals. Only ``Account`` users with a truthy
    ``current_tenant_id`` affect ``contexts.tenant_id``; any other user
    (e.g. an end user) is ignored.

    Note: AccountService.load_logged_in_account will populate
    user.current_tenant_id through the load_user method, which calls
    account.set_tenant_id().
    """
    if not user or not isinstance(user, Account):
        return
    if user.current_tenant_id:
        contexts.tenant_id.set(user.current_tenant_id)
-
-
@login_manager.unauthorized_handler
def unauthorized_handler():
    """Return the JSON 401 response Flask-Login sends for unauthorized requests."""
    payload = {"code": "unauthorized", "message": "Unauthorized."}
    body = json.dumps(payload)
    return Response(body, status=401, content_type="application/json")
-
-
def init_app(app: DifyApp):
    """Bind the module-level Flask-Login manager to the given application."""
    login_manager.init_app(app=app)
|