| @@ -2,7 +2,7 @@ import os | |||
| from configs.app_configs import DifyConfigs | |||
| if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': | |||
| if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true': | |||
| from gevent import monkey | |||
| monkey.patch_all() | |||
| @@ -152,27 +152,26 @@ def initialize_extensions(app): | |||
| @login_manager.request_loader | |||
| def load_user_from_request(request_from_flask_login): | |||
| """Load user based on the request.""" | |||
| if request.blueprint in ['console', 'inner_api']: | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get('Authorization', '') | |||
| if not auth_header: | |||
| auth_token = request.args.get('_token') | |||
| if not auth_token: | |||
| raise Unauthorized('Invalid Authorization token.') | |||
| else: | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get('user_id') | |||
| return AccountService.load_user(user_id) | |||
| else: | |||
| if request.blueprint not in ['console', 'inner_api']: | |||
| return None | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get('Authorization', '') | |||
| if not auth_header: | |||
| auth_token = request.args.get('_token') | |||
| if not auth_token: | |||
| raise Unauthorized('Invalid Authorization token.') | |||
| else: | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get('user_id') | |||
| return AccountService.load_logged_in_account(account_id=user_id, token=auth_token) | |||
| @login_manager.unauthorized_handler | |||
| @@ -1,3 +1,5 @@ | |||
| from typing import cast | |||
| import flask_login | |||
| from flask import current_app, request | |||
| from flask_restful import Resource, reqparse | |||
| @@ -5,8 +7,9 @@ from flask_restful import Resource, reqparse | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from libs.helper import email | |||
| from libs.helper import email, get_remote_ip | |||
| from libs.password import valid_password | |||
| from models.account import Account | |||
| from services.account_service import AccountService, TenantService | |||
| @@ -34,10 +37,7 @@ class LoginApi(Resource): | |||
| if len(tenants) == 0: | |||
| return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} | |||
| AccountService.update_last_login(account, request) | |||
| # todo: return the user info | |||
| token = AccountService.get_account_jwt_token(account) | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| return {'result': 'success', 'data': token} | |||
| @@ -46,6 +46,9 @@ class LogoutApi(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| account = cast(Account, flask_login.current_user) | |||
| token = request.headers.get('Authorization', '').split(' ')[1] | |||
| AccountService.logout(account=account, token=token) | |||
| flask_login.logout_user() | |||
| return {'result': 'success'} | |||
| @@ -8,6 +8,7 @@ from flask_restful import Resource | |||
| from constants.languages import languages | |||
| from extensions.ext_database import db | |||
| from libs.helper import get_remote_ip | |||
| from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |||
| from models.account import Account, AccountStatus | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| @@ -78,9 +79,7 @@ class OAuthCallback(Resource): | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| AccountService.update_last_login(account, request) | |||
| token = AccountService.get_account_jwt_token(account) | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}') | |||
| @@ -4,7 +4,7 @@ from flask import current_app, request | |||
| from flask_restful import Resource, reqparse | |||
| from extensions.ext_database import db | |||
| from libs.helper import email, str_len | |||
| from libs.helper import email, get_remote_ip, str_len | |||
| from libs.password import valid_password | |||
| from models.model import DifySetup | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| @@ -61,7 +61,7 @@ class SetupApi(Resource): | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| setup() | |||
| AccountService.update_last_login(account, request) | |||
| AccountService.update_last_login(account, ip_address=get_remote_ip(request)) | |||
| return {'result': 'success'}, 201 | |||
| @@ -140,7 +140,7 @@ def generate_string(n): | |||
| return result | |||
| def get_remote_ip(request): | |||
| def get_remote_ip(request) -> str: | |||
| if request.headers.get('CF-Connecting-IP'): | |||
| return request.headers.get('Cf-Connecting-Ip') | |||
| elif request.headers.getlist("X-Forwarded-For"): | |||
| @@ -1 +1,3 @@ | |||
| import services.errors | |||
| from . import errors | |||
| __all__ = ['errors'] | |||
| @@ -13,7 +13,6 @@ from werkzeug.exceptions import Unauthorized | |||
| from constants.languages import language_timezone_mapping, languages | |||
| from events.tenant_event import tenant_was_created | |||
| from extensions.ext_redis import redis_client | |||
| from libs.helper import get_remote_ip | |||
| from libs.passport import PassportService | |||
| from libs.password import compare_password, hash_password, valid_password | |||
| from libs.rsa import generate_key_pair | |||
| @@ -67,10 +66,10 @@ class AccountService: | |||
| @staticmethod | |||
| def get_account_jwt_token(account): | |||
| def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): | |||
| payload = { | |||
| "user_id": account.id, | |||
| "exp": datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30), | |||
| "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, | |||
| "iss": current_app.config['EDITION'], | |||
| "sub": 'Console API Passport', | |||
| } | |||
| @@ -195,14 +194,35 @@ class AccountService: | |||
| return account | |||
| @staticmethod | |||
| def update_last_login(account: Account, request) -> None: | |||
| def update_last_login(account: Account, *, ip_address: str) -> None: | |||
| """Update last login time and ip""" | |||
| account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| account.last_login_ip = get_remote_ip(request) | |||
| account.last_login_ip = ip_address | |||
| db.session.add(account) | |||
| db.session.commit() | |||
| logging.info(f'Account {account.id} logged in successfully.') | |||
| @staticmethod | |||
| def login(account: Account, *, ip_address: Optional[str] = None): | |||
| if ip_address: | |||
| AccountService.update_last_login(account, ip_address=ip_address) | |||
| exp = timedelta(days=30) | |||
| token = AccountService.get_account_jwt_token(account, exp=exp) | |||
| redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) | |||
| return token | |||
| @staticmethod | |||
| def logout(*, account: Account, token: str): | |||
| redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) | |||
| @staticmethod | |||
| def load_logged_in_account(*, account_id: str, token: str): | |||
| if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): | |||
| return None | |||
| return AccountService.load_user(account_id) | |||
| def _get_login_cache_key(*, account_id: str, token: str): | |||
| return f"account_login:{account_id}:{token}" | |||
| class TenantService: | |||
| @@ -1,6 +1,29 @@ | |||
| from . import ( | |||
| account, | |||
| app, | |||
| app_model_config, | |||
| audio, | |||
| base, | |||
| completion, | |||
| conversation, | |||
| dataset, | |||
| document, | |||
| file, | |||
| index, | |||
| message, | |||
| ) | |||
| __all__ = [ | |||
| 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', | |||
| 'app', 'completion', 'audio', 'file' | |||
| "base", | |||
| "conversation", | |||
| "message", | |||
| "index", | |||
| "app_model_config", | |||
| "account", | |||
| "document", | |||
| "dataset", | |||
| "app", | |||
| "completion", | |||
| "audio", | |||
| "file", | |||
| ] | |||
| from . import * | |||