| @@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource): | |||
| if not reset_data: | |||
| raise InvalidTokenError() | |||
| # Must use token in reset phase | |||
| if reset_data.get("phase", "") != "reset": | |||
| raise InvalidTokenError() | |||
| # Must use token in reset phase | |||
| if reset_data.get("phase", "") != "reset": | |||
| raise InvalidTokenError() | |||
| @@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource): | |||
| if FeatureService.get_system_features().webapp_auth.enabled: | |||
| user_id = current_user.id | |||
| res = [] | |||
| app_ids = [installed_app["app"].id for installed_app in installed_app_list] | |||
| webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) | |||
| for installed_app in installed_app_list: | |||
| webapp_setting = webapp_settings.get(installed_app["app"].id) | |||
| if not webapp_setting: | |||
| continue | |||
| if webapp_setting.access_mode == "sso_verified": | |||
| continue | |||
| app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) | |||
| if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( | |||
| user_id=user_id, | |||
| @@ -44,6 +44,17 @@ def only_edition_cloud(view): | |||
| return decorated | |||
| def only_edition_enterprise(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not dify_config.ENTERPRISE_ENABLED: | |||
| abort(404) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| def only_edition_self_hosted(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| @@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload") | |||
| api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") | |||
| api.add_resource(RemoteFileUploadApi, "/remote-files/upload") | |||
| from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow | |||
| from . import ( | |||
| app, | |||
| audio, | |||
| completion, | |||
| conversation, | |||
| feature, | |||
| forgot_password, | |||
| login, | |||
| message, | |||
| passport, | |||
| saved_message, | |||
| site, | |||
| workflow, | |||
| ) | |||
| @@ -10,6 +10,8 @@ from libs.passport import PassportService | |||
| from models.model import App, AppMode | |||
| from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| class AppParameterApi(WebApiResource): | |||
| @@ -46,10 +48,22 @@ class AppMeta(WebApiResource): | |||
| class AppAccessMode(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("appId", type=str, required=True, location="args") | |||
| parser.add_argument("appId", type=str, required=False, location="args") | |||
| parser.add_argument("appCode", type=str, required=False, location="args") | |||
| args = parser.parse_args() | |||
| app_id = args["appId"] | |||
| features = FeatureService.get_system_features() | |||
| if not features.webapp_auth.enabled: | |||
| return {"accessMode": "public"} | |||
| app_id = args.get("appId") | |||
| if args.get("appCode"): | |||
| app_code = args["appCode"] | |||
| app_id = AppService.get_app_id_by_code(app_code) | |||
| if not app_id: | |||
| raise ValueError("appId or appCode must be provided") | |||
| res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) | |||
| return {"accessMode": res.access_mode} | |||
| @@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource): | |||
| except Exception as e: | |||
| pass | |||
| features = FeatureService.get_system_features() | |||
| if not features.webapp_auth.enabled: | |||
| return {"result": True} | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("appId", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| @@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource): | |||
| app_id = args["appId"] | |||
| app_code = AppService.get_app_code_by_id(app_id) | |||
| res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) | |||
| res = True | |||
| if WebAppAuthService.is_app_require_permission_check(app_id=app_id): | |||
| res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) | |||
| return {"result": res} | |||
| @@ -0,0 +1,147 @@ | |||
| import base64 | |||
| import secrets | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from controllers.console.auth.error import ( | |||
| EmailCodeError, | |||
| EmailPasswordResetLimitError, | |||
| InvalidEmailError, | |||
| InvalidTokenError, | |||
| PasswordMismatchError, | |||
| ) | |||
| from controllers.console.error import AccountNotFound, EmailSendIpLimitError | |||
| from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required | |||
| from controllers.web import api | |||
| from extensions.ext_database import db | |||
| from libs.helper import email, extract_remote_ip | |||
| from libs.password import hash_password, valid_password | |||
| from models.account import Account | |||
| from services.account_service import AccountService | |||
| class ForgotPasswordSendEmailApi(Resource): | |||
| @only_edition_enterprise | |||
| @setup_required | |||
| @email_password_login_enabled | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("language", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| ip_address = extract_remote_ip(request) | |||
| if AccountService.is_email_send_ip_limit(ip_address): | |||
| raise EmailSendIpLimitError() | |||
| if args["language"] is not None and args["language"] == "zh-Hans": | |||
| language = "zh-Hans" | |||
| else: | |||
| language = "en-US" | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() | |||
| token = None | |||
| if account is None: | |||
| raise AccountNotFound() | |||
| else: | |||
| token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) | |||
| return {"result": "success", "data": token} | |||
| class ForgotPasswordCheckApi(Resource): | |||
| @only_edition_enterprise | |||
| @setup_required | |||
| @email_password_login_enabled | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=str, required=True, location="json") | |||
| parser.add_argument("code", type=str, required=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| user_email = args["email"] | |||
| is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) | |||
| if is_forgot_password_error_rate_limit: | |||
| raise EmailPasswordResetLimitError() | |||
| token_data = AccountService.get_reset_password_data(args["token"]) | |||
| if token_data is None: | |||
| raise InvalidTokenError() | |||
| if user_email != token_data.get("email"): | |||
| raise InvalidEmailError() | |||
| if args["code"] != token_data.get("code"): | |||
| AccountService.add_forgot_password_error_rate_limit(args["email"]) | |||
| raise EmailCodeError() | |||
| # Verified, revoke the first token | |||
| AccountService.revoke_reset_password_token(args["token"]) | |||
| # Refresh token data by generating a new token | |||
| _, new_token = AccountService.generate_reset_password_token( | |||
| user_email, code=args["code"], additional_data={"phase": "reset"} | |||
| ) | |||
| AccountService.reset_forgot_password_error_rate_limit(args["email"]) | |||
| return {"is_valid": True, "email": token_data.get("email"), "token": new_token} | |||
| class ForgotPasswordResetApi(Resource): | |||
| @only_edition_enterprise | |||
| @setup_required | |||
| @email_password_login_enabled | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") | |||
| parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # Validate passwords match | |||
| if args["new_password"] != args["password_confirm"]: | |||
| raise PasswordMismatchError() | |||
| # Validate token and get reset data | |||
| reset_data = AccountService.get_reset_password_data(args["token"]) | |||
| if not reset_data: | |||
| raise InvalidTokenError() | |||
| # Must use token in reset phase | |||
| if reset_data.get("phase", "") != "reset": | |||
| raise InvalidTokenError() | |||
| # Revoke token to prevent reuse | |||
| AccountService.revoke_reset_password_token(args["token"]) | |||
| # Generate secure salt and hash password | |||
| salt = secrets.token_bytes(16) | |||
| password_hashed = hash_password(args["new_password"], salt) | |||
| email = reset_data.get("email", "") | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() | |||
| if account: | |||
| self._update_existing_account(account, password_hashed, salt, session) | |||
| else: | |||
| raise AccountNotFound() | |||
| return {"result": "success"} | |||
| def _update_existing_account(self, account, password_hashed, salt, session): | |||
| # Update existing account credentials | |||
| account.password = base64.b64encode(password_hashed).decode() | |||
| account.password_salt = base64.b64encode(salt).decode() | |||
| session.commit() | |||
| api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") | |||
| api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") | |||
| api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") | |||
| @@ -1,12 +1,11 @@ | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from jwt import InvalidTokenError # type: ignore | |||
| from werkzeug.exceptions import BadRequest | |||
| import services | |||
| from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError | |||
| from controllers.console.error import AccountBannedError, AccountNotFound | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.console.wraps import only_edition_enterprise, setup_required | |||
| from controllers.web import api | |||
| from libs.helper import email | |||
| from libs.password import valid_password | |||
| from services.account_service import AccountService | |||
| @@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService | |||
| class LoginApi(Resource): | |||
| """Resource for web app email/password login.""" | |||
| @setup_required | |||
| @only_edition_enterprise | |||
| def post(self): | |||
| """Authenticate user and login.""" | |||
| parser = reqparse.RequestParser() | |||
| @@ -23,10 +24,6 @@ class LoginApi(Resource): | |||
| parser.add_argument("password", type=valid_password, required=True, location="json") | |||
| args = parser.parse_args() | |||
| app_code = request.headers.get("X-App-Code") | |||
| if app_code is None: | |||
| raise BadRequest("X-App-Code header is missing.") | |||
| try: | |||
| account = WebAppAuthService.authenticate(args["email"], args["password"]) | |||
| except services.errors.account.AccountLoginError: | |||
| @@ -36,12 +33,8 @@ class LoginApi(Resource): | |||
| except services.errors.account.AccountNotFoundError: | |||
| raise AccountNotFound() | |||
| WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) | |||
| end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) | |||
| token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) | |||
| return {"result": "success", "token": token} | |||
| token = WebAppAuthService.login(account=account) | |||
| return {"result": "success", "data": {"access_token": token}} | |||
| # class LogoutApi(Resource): | |||
| @@ -56,6 +49,7 @@ class LoginApi(Resource): | |||
| class EmailCodeLoginSendEmailApi(Resource): | |||
| @setup_required | |||
| @only_edition_enterprise | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| @@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource): | |||
| class EmailCodeLoginApi(Resource): | |||
| @setup_required | |||
| @only_edition_enterprise | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=str, required=True, location="json") | |||
| @@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource): | |||
| args = parser.parse_args() | |||
| user_email = args["email"] | |||
| app_code = request.headers.get("X-App-Code") | |||
| if app_code is None: | |||
| raise BadRequest("X-App-Code header is missing.") | |||
| token_data = WebAppAuthService.get_email_code_login_data(args["token"]) | |||
| if token_data is None: | |||
| @@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource): | |||
| if not account: | |||
| raise AccountNotFound() | |||
| WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) | |||
| end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) | |||
| token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) | |||
| token = WebAppAuthService.login(account=account) | |||
| AccountService.reset_login_error_rate_limit(args["email"]) | |||
| return {"result": "success", "token": token} | |||
| return {"result": "success", "data": {"access_token": token}} | |||
| # api.add_resource(LoginApi, "/login") | |||
| api.add_resource(LoginApi, "/login") | |||
| # api.add_resource(LogoutApi, "/logout") | |||
| # api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") | |||
| # api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") | |||
| api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") | |||
| api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") | |||
| @@ -1,9 +1,11 @@ | |||
| import uuid | |||
| from datetime import UTC, datetime, timedelta | |||
| from flask import request | |||
| from flask_restful import Resource | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from configs import dify_config | |||
| from controllers.web import api | |||
| from controllers.web.error import WebAppAuthRequiredError | |||
| from extensions.ext_database import db | |||
| @@ -11,6 +13,7 @@ from libs.passport import PassportService | |||
| from models.model import App, EndUser, Site | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService, WebAppAuthType | |||
| class PassportResource(Resource): | |||
| @@ -20,10 +23,19 @@ class PassportResource(Resource): | |||
| system_features = FeatureService.get_system_features() | |||
| app_code = request.headers.get("X-App-Code") | |||
| user_id = request.args.get("user_id") | |||
| web_app_access_token = request.args.get("web_app_access_token") | |||
| if app_code is None: | |||
| raise Unauthorized("X-App-Code header is missing.") | |||
| # exchange token for enterprise logined web user | |||
| enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) | |||
| if enterprise_user_decoded: | |||
| # a web user has already logged in, exchange a token for this app without redirecting to the login page | |||
| return exchange_token_for_existing_web_user( | |||
| app_code=app_code, enterprise_user_decoded=enterprise_user_decoded | |||
| ) | |||
| if system_features.webapp_auth.enabled: | |||
| app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) | |||
| if not app_settings or not app_settings.access_mode == "public": | |||
| @@ -84,6 +96,128 @@ class PassportResource(Resource): | |||
| api.add_resource(PassportResource, "/passport") | |||
| def decode_enterprise_webapp_user_id(jwt_token: str | None): | |||
| """ | |||
| Decode the enterprise user session from the Authorization header. | |||
| """ | |||
| if not jwt_token: | |||
| return None | |||
| decoded = PassportService().verify(jwt_token) | |||
| source = decoded.get("token_source") | |||
| if not source or source != "webapp_login_token": | |||
| raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") | |||
| return decoded | |||
| def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): | |||
| """ | |||
| Exchange a token for an existing web user session. | |||
| """ | |||
| user_id = enterprise_user_decoded.get("user_id") | |||
| end_user_id = enterprise_user_decoded.get("end_user_id") | |||
| session_id = enterprise_user_decoded.get("session_id") | |||
| user_auth_type = enterprise_user_decoded.get("auth_type") | |||
| if not user_auth_type: | |||
| raise Unauthorized("Missing auth_type in the token.") | |||
| site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() | |||
| if not site: | |||
| raise NotFound() | |||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | |||
| if not app_model or app_model.status != "normal" or not app_model.enable_site: | |||
| raise NotFound() | |||
| app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) | |||
| if app_auth_type == WebAppAuthType.PUBLIC: | |||
| return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) | |||
| elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": | |||
| raise WebAppAuthRequiredError("Please login as external user.") | |||
| elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": | |||
| raise WebAppAuthRequiredError("Please login as internal user.") | |||
| end_user = None | |||
| if end_user_id: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() | |||
| if session_id: | |||
| end_user = ( | |||
| db.session.query(EndUser) | |||
| .filter( | |||
| EndUser.session_id == session_id, | |||
| EndUser.tenant_id == app_model.tenant_id, | |||
| EndUser.app_id == app_model.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not end_user: | |||
| if not session_id: | |||
| raise NotFound("Missing session_id for existing web user.") | |||
| end_user = EndUser( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type="browser", | |||
| is_anonymous=True, | |||
| session_id=session_id, | |||
| ) | |||
| db.session.add(end_user) | |||
| db.session.commit() | |||
| exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) | |||
| exp = int(exp_dt.timestamp()) | |||
| payload = { | |||
| "iss": site.id, | |||
| "sub": "Web API Passport", | |||
| "app_id": site.app_id, | |||
| "app_code": site.code, | |||
| "user_id": user_id, | |||
| "end_user_id": end_user.id, | |||
| "auth_type": user_auth_type, | |||
| "granted_at": int(datetime.now(UTC).timestamp()), | |||
| "token_source": "webapp", | |||
| "exp": exp, | |||
| } | |||
| token: str = PassportService().issue(payload) | |||
| return { | |||
| "access_token": token, | |||
| } | |||
| def _exchange_for_public_app_token(app_model, site, token_decoded): | |||
| user_id = token_decoded.get("user_id") | |||
| end_user = None | |||
| if user_id: | |||
| end_user = ( | |||
| db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() | |||
| ) | |||
| if not end_user: | |||
| end_user = EndUser( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type="browser", | |||
| is_anonymous=True, | |||
| session_id=generate_session_id(), | |||
| ) | |||
| db.session.add(end_user) | |||
| db.session.commit() | |||
| payload = { | |||
| "iss": site.app_id, | |||
| "sub": "Web API Passport", | |||
| "app_id": site.app_id, | |||
| "app_code": site.code, | |||
| "end_user_id": end_user.id, | |||
| } | |||
| tk = PassportService().issue(payload) | |||
| return { | |||
| "access_token": tk, | |||
| } | |||
| def generate_session_id(): | |||
| """ | |||
| Generate a unique session ID. | |||
| @@ -1,3 +1,4 @@ | |||
| from datetime import UTC, datetime | |||
| from functools import wraps | |||
| from flask import request | |||
| @@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire | |||
| from extensions.ext_database import db | |||
| from libs.passport import PassportService | |||
| from models.model import App, EndUser, Site | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| def validate_jwt_token(view=None): | |||
| @@ -45,7 +47,8 @@ def decode_jwt_token(): | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| decoded = PassportService().verify(tk) | |||
| app_code = decoded.get("app_code") | |||
| app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() | |||
| app_id = decoded.get("app_id") | |||
| app_model = db.session.query(App).filter(App.id == app_id).first() | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| if not app_model: | |||
| raise NotFound() | |||
| @@ -53,23 +56,30 @@ def decode_jwt_token(): | |||
| raise BadRequest("Site URL is no longer valid.") | |||
| if app_model.enable_site is False: | |||
| raise BadRequest("Site is disabled.") | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() | |||
| end_user_id = decoded.get("end_user_id") | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() | |||
| if not end_user: | |||
| raise NotFound() | |||
| # for enterprise webapp auth | |||
| app_web_auth_enabled = False | |||
| webapp_settings = None | |||
| if system_features.webapp_auth.enabled: | |||
| app_web_auth_enabled = ( | |||
| EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" | |||
| ) | |||
| webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) | |||
| if not webapp_settings: | |||
| raise NotFound("Web app settings not found.") | |||
| app_web_auth_enabled = webapp_settings.access_mode != "public" | |||
| _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) | |||
| _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) | |||
| _validate_user_accessibility( | |||
| decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings | |||
| ) | |||
| return app_model, end_user | |||
| except Unauthorized as e: | |||
| if system_features.webapp_auth.enabled: | |||
| if not app_code: | |||
| raise Unauthorized("Please re-login to access the web app.") | |||
| app_web_auth_enabled = ( | |||
| EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" | |||
| ) | |||
| @@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au | |||
| raise Unauthorized("webapp token expired.") | |||
| def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): | |||
| def _validate_user_accessibility( | |||
| decoded, | |||
| app_code, | |||
| app_web_auth_enabled: bool, | |||
| system_webapp_auth_enabled: bool, | |||
| webapp_settings: WebAppSettings | None, | |||
| ): | |||
| if system_webapp_auth_enabled and app_web_auth_enabled: | |||
| # Check if the user is allowed to access the web app | |||
| user_id = decoded.get("user_id") | |||
| if not user_id: | |||
| raise WebAppAuthRequiredError() | |||
| if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): | |||
| raise WebAppAuthAccessDeniedError() | |||
| if not webapp_settings: | |||
| raise WebAppAuthRequiredError("Web app settings not found.") | |||
| if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): | |||
| if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): | |||
| raise WebAppAuthAccessDeniedError() | |||
| auth_type = decoded.get("auth_type") | |||
| granted_at = decoded.get("granted_at") | |||
| if not auth_type: | |||
| raise WebAppAuthAccessDeniedError("Missing auth_type in the token.") | |||
| if not granted_at: | |||
| raise WebAppAuthAccessDeniedError("Missing granted_at in the token.") | |||
| # check if sso has been updated | |||
| if auth_type == "external": | |||
| last_update_time = EnterpriseService.get_app_sso_settings_last_update_time() | |||
| if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: | |||
| raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") | |||
| elif auth_type == "internal": | |||
| last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time() | |||
| if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: | |||
| raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") | |||
| class WebApiResource(Resource): | |||
| @@ -57,6 +57,9 @@ def load_user_from_request(request_from_flask_login): | |||
| raise Unauthorized("Invalid Authorization token.") | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get("user_id") | |||
| source = decoded.get("token_source") | |||
| if source: | |||
| raise Unauthorized("Invalid Authorization token.") | |||
| if not user_id: | |||
| raise Unauthorized("Invalid Authorization token.") | |||
| @@ -395,3 +395,15 @@ class AppService: | |||
| if not site: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| return str(site.code) | |||
| @staticmethod | |||
| def get_app_id_by_code(app_code: str) -> str: | |||
| """ | |||
| Get app id by app code | |||
| :param app_code: app code | |||
| :return: app id | |||
| """ | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| if not site: | |||
| raise ValueError(f"App with code {app_code} not found") | |||
| return str(site.app_id) | |||
| @@ -1,3 +1,5 @@ | |||
| from datetime import datetime | |||
| from pydantic import BaseModel, Field | |||
| from services.enterprise.base import EnterpriseRequest | |||
| @@ -5,7 +7,7 @@ from services.enterprise.base import EnterpriseRequest | |||
| class WebAppSettings(BaseModel): | |||
| access_mode: str = Field( | |||
| description="Access mode for the web app. Can be 'public' or 'private'", | |||
| description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'", | |||
| default="private", | |||
| alias="accessMode", | |||
| ) | |||
| @@ -20,6 +22,28 @@ class EnterpriseService: | |||
| def get_workspace_info(cls, tenant_id: str): | |||
| return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") | |||
| @classmethod | |||
| def get_app_sso_settings_last_update_time(cls) -> datetime: | |||
| data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") | |||
| if not data: | |||
| raise ValueError("No data found.") | |||
| try: | |||
| # parse the UTC timestamp from the response | |||
| return datetime.fromisoformat(data.replace("Z", "+00:00")) | |||
| except ValueError as e: | |||
| raise ValueError(f"Invalid date format: {data}") from e | |||
| @classmethod | |||
| def get_workspace_sso_settings_last_update_time(cls) -> datetime: | |||
| data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time") | |||
| if not data: | |||
| raise ValueError("No data found.") | |||
| try: | |||
| # parse the UTC timestamp from the response | |||
| return datetime.fromisoformat(data.replace("Z", "+00:00")) | |||
| except ValueError as e: | |||
| raise ValueError(f"Invalid date format: {data}") from e | |||
| class WebAppAuth: | |||
| @classmethod | |||
| def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): | |||
| @@ -1,3 +1,4 @@ | |||
| import enum | |||
| import secrets | |||
| from datetime import UTC, datetime, timedelta | |||
| from typing import Any, Optional, cast | |||
| @@ -5,27 +6,33 @@ from typing import Any, Optional, cast | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from configs import dify_config | |||
| from controllers.web.error import WebAppAuthAccessDeniedError | |||
| from extensions.ext_database import db | |||
| from libs.helper import TokenManager | |||
| from libs.passport import PassportService | |||
| from libs.password import compare_password | |||
| from models.account import Account, AccountStatus | |||
| from models.model import App, EndUser, Site | |||
| from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError | |||
| from services.feature_service import FeatureService | |||
| from tasks.mail_email_code_login import send_email_code_login_mail_task | |||
| class WebAppAuthType(enum.StrEnum): | |||
| """Enum for web app authentication types.""" | |||
| PUBLIC = "public" | |||
| INTERNAL = "internal" | |||
| EXTERNAL = "external" | |||
| class WebAppAuthService: | |||
| """Service for web app authentication.""" | |||
| @staticmethod | |||
| def authenticate(email: str, password: str) -> Account: | |||
| """authenticate account with email and password""" | |||
| account = Account.query.filter_by(email=email).first() | |||
| account = db.session.query(Account).filter_by(email=email).first() | |||
| if not account: | |||
| raise AccountNotFoundError() | |||
| @@ -38,12 +45,8 @@ class WebAppAuthService: | |||
| return cast(Account, account) | |||
| @classmethod | |||
| def login(cls, account: Account, app_code: str, end_user_id: str) -> str: | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| if not site: | |||
| raise NotFound("Site not found.") | |||
| access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) | |||
| def login(cls, account: Account) -> str: | |||
| access_token = cls._get_account_jwt_token(account=account) | |||
| return access_token | |||
| @@ -68,7 +71,7 @@ class WebAppAuthService: | |||
| code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) | |||
| token = TokenManager.generate_token( | |||
| account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} | |||
| account=account, email=email, token_type="email_code_login", additional_data={"code": code} | |||
| ) | |||
| send_email_code_login_mail_task.delay( | |||
| language=language, | |||
| @@ -80,11 +83,11 @@ class WebAppAuthService: | |||
| @classmethod | |||
| def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: | |||
| return TokenManager.get_token_data(token, "webapp_email_code_login") | |||
| return TokenManager.get_token_data(token, "email_code_login") | |||
| @classmethod | |||
| def revoke_email_code_login_token(cls, token: str): | |||
| TokenManager.revoke_token(token, "webapp_email_code_login") | |||
| TokenManager.revoke_token(token, "email_code_login") | |||
| @classmethod | |||
| def create_end_user(cls, app_code, email) -> EndUser: | |||
| @@ -109,33 +112,67 @@ class WebAppAuthService: | |||
| return end_user | |||
| @classmethod | |||
| def _validate_user_accessibility(cls, account: Account, app_code: str): | |||
| """Check if the user is allowed to access the app.""" | |||
| system_features = FeatureService.get_system_features() | |||
| if system_features.webapp_auth.enabled: | |||
| app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) | |||
| if ( | |||
| app_settings.access_mode != "public" | |||
| and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) | |||
| ): | |||
| raise WebAppAuthAccessDeniedError() | |||
| @classmethod | |||
| def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: | |||
| def _get_account_jwt_token(cls, account: Account) -> str: | |||
| exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) | |||
| exp = int(exp_dt.timestamp()) | |||
| payload = { | |||
| "iss": site.id, | |||
| "sub": "Web API Passport", | |||
| "app_id": site.app_id, | |||
| "app_code": site.code, | |||
| "user_id": account.id, | |||
| "end_user_id": end_user_id, | |||
| "token_source": "webapp", | |||
| "session_id": account.email, | |||
| "token_source": "webapp_login_token", | |||
| "auth_type": "internal", | |||
| "exp": exp, | |||
| } | |||
| token: str = PassportService().issue(payload) | |||
| return token | |||
| @classmethod | |||
| def is_app_require_permission_check( | |||
| cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None | |||
| ) -> bool: | |||
| """ | |||
| Check if the app requires permission check based on its access mode. | |||
| """ | |||
| modes_requiring_permission_check = [ | |||
| "private", | |||
| "private_all", | |||
| ] | |||
| if access_mode: | |||
| return access_mode in modes_requiring_permission_check | |||
| if not app_code and not app_id: | |||
| raise ValueError("Either app_code or app_id must be provided.") | |||
| if app_code: | |||
| app_id = AppService.get_app_id_by_code(app_code) | |||
| if not app_id: | |||
| raise ValueError("App ID could not be determined from the provided app_code.") | |||
| webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) | |||
| if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: | |||
| return True | |||
| return False | |||
| @classmethod | |||
| def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType: | |||
| """ | |||
| Get the authentication type for the app based on its access mode. | |||
| """ | |||
| if not app_code and not access_mode: | |||
| raise ValueError("Either app_code or access_mode must be provided.") | |||
| if access_mode: | |||
| if access_mode == "public": | |||
| return WebAppAuthType.PUBLIC | |||
| elif access_mode in ["private", "private_all"]: | |||
| return WebAppAuthType.INTERNAL | |||
| elif access_mode == "sso_verified": | |||
| return WebAppAuthType.EXTERNAL | |||
| if app_code: | |||
| webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) | |||
| return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) | |||
| raise ValueError("Could not determine app authentication type.") | |||