| @@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001 | |||
| # The time in seconds after the signature is rejected | |||
| FILES_ACCESS_TIMEOUT=300 | |||
| # Access token expiration time in minutes | |||
| ACCESS_TOKEN_EXPIRE_MINUTES=60 | |||
| # celery configuration | |||
| CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 | |||
| @@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login): | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get("user_id") | |||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) | |||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | |||
| if logged_in_account: | |||
| contexts.tenant_id.set(logged_in_account.current_tenant_id) | |||
| return logged_in_account | |||
| @@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings): | |||
| ) | |||
| class OAuthConfig(BaseSettings): | |||
| class AuthConfig(BaseSettings): | |||
| """ | |||
| Configuration for OAuth authentication | |||
| Configuration for authentication and OAuth | |||
| """ | |||
| OAUTH_REDIRECT_PATH: str = Field( | |||
| @@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings): | |||
| ) | |||
| GITHUB_CLIENT_ID: Optional[str] = Field( | |||
| description="GitHub OAuth client secret", | |||
| description="GitHub OAuth client ID", | |||
| default=None, | |||
| ) | |||
| @@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings): | |||
| default=None, | |||
| ) | |||
| ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field( | |||
| description="Expiration time for access tokens in minutes", | |||
| default=60, | |||
| ) | |||
| class ModerationConfig(BaseSettings): | |||
| """ | |||
| @@ -607,6 +612,7 @@ class PositionConfig(BaseSettings): | |||
| class FeatureConfig( | |||
| # place the configs in alphabet order | |||
| AppExecutionConfig, | |||
| AuthConfig, # Changed from OAuthConfig to AuthConfig | |||
| BillingConfig, | |||
| CodeExecutionSandboxConfig, | |||
| DataSetConfig, | |||
| @@ -621,14 +627,13 @@ class FeatureConfig( | |||
| MailConfig, | |||
| ModelLoadBalanceConfig, | |||
| ModerationConfig, | |||
| OAuthConfig, | |||
| PositionConfig, | |||
| RagEtlConfig, | |||
| SecurityConfig, | |||
| ToolConfig, | |||
| UpdateConfig, | |||
| WorkflowConfig, | |||
| WorkspaceConfig, | |||
| PositionConfig, | |||
| # hosted services config | |||
| HostedServiceConfig, | |||
| CeleryBeatConfig, | |||
| @@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from libs.helper import email, get_remote_ip | |||
| from libs.helper import email, extract_remote_ip | |||
| from libs.password import valid_password | |||
| from models.account import Account | |||
| from services.account_service import AccountService, TenantService | |||
| @@ -40,17 +40,16 @@ class LoginApi(Resource): | |||
| "data": "workspace not found, please contact system admin to invite you to join in a workspace", | |||
| } | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) | |||
| return {"result": "success", "data": token} | |||
| return {"result": "success", "data": token_pair.model_dump()} | |||
| class LogoutApi(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| account = cast(Account, flask_login.current_user) | |||
| token = request.headers.get("Authorization", "").split(" ")[1] | |||
| AccountService.logout(account=account, token=token) | |||
| AccountService.logout(account=account) | |||
| flask_login.logout_user() | |||
| return {"result": "success"} | |||
| @@ -106,5 +105,19 @@ class ResetPasswordApi(Resource): | |||
| return {"result": "success"} | |||
| class RefreshTokenApi(Resource): | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("refresh_token", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| new_token_pair = AccountService.refresh_token(args["refresh_token"]) | |||
| return {"result": "success", "data": new_token_pair.model_dump()} | |||
| except Exception as e: | |||
| return {"result": "fail", "data": str(e)}, 401 | |||
| api.add_resource(LoginApi, "/login") | |||
| api.add_resource(LogoutApi, "/logout") | |||
| api.add_resource(RefreshTokenApi, "/refresh-token") | |||
| @@ -9,7 +9,7 @@ from flask_restful import Resource | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| from extensions.ext_database import db | |||
| from libs.helper import get_remote_ip | |||
| from libs.helper import extract_remote_ip | |||
| from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |||
| from models.account import Account, AccountStatus | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| @@ -81,9 +81,14 @@ class OAuthCallback(Resource): | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| token_pair = AccountService.login( | |||
| account=account, | |||
| ip_address=extract_remote_ip(request), | |||
| ) | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") | |||
| return redirect( | |||
| f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" | |||
| ) | |||
| def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | |||
| @@ -4,7 +4,7 @@ from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from configs import dify_config | |||
| from libs.helper import StrLen, email, get_remote_ip | |||
| from libs.helper import StrLen, email, extract_remote_ip | |||
| from libs.password import valid_password | |||
| from models.model import DifySetup | |||
| from services.account_service import RegisterService, TenantService | |||
| @@ -46,7 +46,7 @@ class SetupApi(Resource): | |||
| # setup | |||
| RegisterService.setup( | |||
| email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) | |||
| email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) | |||
| ) | |||
| return {"result": "success"}, 201 | |||
| @@ -162,7 +162,7 @@ def generate_string(n): | |||
| return result | |||
| def get_remote_ip(request) -> str: | |||
| def extract_remote_ip(request) -> str: | |||
| if request.headers.get("CF-Connecting-IP"): | |||
| return request.headers.get("Cf-Connecting-Ip") | |||
| elif request.headers.getlist("X-Forwarded-For"): | |||
| @@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone | |||
| from hashlib import sha256 | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import func | |||
| from werkzeug.exceptions import Unauthorized | |||
| @@ -49,9 +50,39 @@ from tasks.mail_invite_member_task import send_invite_member_mail_task | |||
| from tasks.mail_reset_password_task import send_reset_password_mail_task | |||
| class TokenPair(BaseModel): | |||
| access_token: str | |||
| refresh_token: str | |||
| REFRESH_TOKEN_PREFIX = "refresh_token:" | |||
| ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" | |||
| REFRESH_TOKEN_EXPIRY = timedelta(days=30) | |||
| class AccountService: | |||
| reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) | |||
| @staticmethod | |||
| def _get_refresh_token_key(refresh_token: str) -> str: | |||
| return f"{REFRESH_TOKEN_PREFIX}{refresh_token}" | |||
| @staticmethod | |||
| def _get_account_refresh_token_key(account_id: str) -> str: | |||
| return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" | |||
| @staticmethod | |||
| def _store_refresh_token(refresh_token: str, account_id: str) -> None: | |||
| redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) | |||
| redis_client.setex( | |||
| AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token | |||
| ) | |||
| @staticmethod | |||
| def _delete_refresh_token(refresh_token: str, account_id: str) -> None: | |||
| redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) | |||
| redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) | |||
| @staticmethod | |||
| def load_user(user_id: str) -> None | Account: | |||
| account = Account.query.filter_by(id=user_id).first() | |||
| @@ -61,9 +92,7 @@ class AccountService: | |||
| if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: | |||
| raise Unauthorized("Account is banned or closed.") | |||
| current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( | |||
| account_id=account.id, current=True | |||
| ).first() | |||
| current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | |||
| if current_tenant: | |||
| account.current_tenant_id = current_tenant.tenant_id | |||
| else: | |||
| @@ -84,10 +113,12 @@ class AccountService: | |||
| return account | |||
| @staticmethod | |||
| def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): | |||
| def get_account_jwt_token(account: Account) -> str: | |||
| exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) | |||
| exp = int(exp_dt.timestamp()) | |||
| payload = { | |||
| "user_id": account.id, | |||
| "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, | |||
| "exp": exp, | |||
| "iss": dify_config.EDITION, | |||
| "sub": "Console API Passport", | |||
| } | |||
| @@ -213,7 +244,7 @@ class AccountService: | |||
| return account | |||
| @staticmethod | |||
| def update_last_login(account: Account, *, ip_address: str) -> None: | |||
| def update_login_info(account: Account, *, ip_address: str) -> None: | |||
| """Update last login time and ip""" | |||
| account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| account.last_login_ip = ip_address | |||
| @@ -221,22 +252,45 @@ class AccountService: | |||
| db.session.commit() | |||
| @staticmethod | |||
| def login(account: Account, *, ip_address: Optional[str] = None): | |||
| def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: | |||
| if ip_address: | |||
| AccountService.update_last_login(account, ip_address=ip_address) | |||
| exp = timedelta(days=30) | |||
| token = AccountService.get_account_jwt_token(account, exp=exp) | |||
| redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) | |||
| return token | |||
| AccountService.update_login_info(account=account, ip_address=ip_address) | |||
| access_token = AccountService.get_account_jwt_token(account=account) | |||
| refresh_token = _generate_refresh_token() | |||
| AccountService._store_refresh_token(refresh_token, account.id) | |||
| return TokenPair(access_token=access_token, refresh_token=refresh_token) | |||
| @staticmethod | |||
| def logout(*, account: Account, token: str): | |||
| redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) | |||
| def logout(*, account: Account) -> None: | |||
| refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) | |||
| if refresh_token: | |||
| AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) | |||
| @staticmethod | |||
| def load_logged_in_account(*, account_id: str, token: str): | |||
| if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): | |||
| return None | |||
| def refresh_token(refresh_token: str) -> TokenPair: | |||
| # Verify the refresh token | |||
| account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) | |||
| if not account_id: | |||
| raise ValueError("Invalid refresh token") | |||
| account = AccountService.load_user(account_id.decode("utf-8")) | |||
| if not account: | |||
| raise ValueError("Invalid account") | |||
| # Generate new access token and refresh token | |||
| new_access_token = AccountService.get_account_jwt_token(account) | |||
| new_refresh_token = _generate_refresh_token() | |||
| AccountService._delete_refresh_token(refresh_token, account.id) | |||
| AccountService._store_refresh_token(new_refresh_token, account.id) | |||
| return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) | |||
| @staticmethod | |||
| def load_logged_in_account(*, account_id: str): | |||
| return AccountService.load_user(account_id) | |||
| @classmethod | |||
| @@ -258,10 +312,6 @@ class AccountService: | |||
| return TokenManager.get_token_data(token, "reset_password") | |||
| def _get_login_cache_key(*, account_id: str, token: str): | |||
| return f"account_login:{account_id}:{token}" | |||
| class TenantService: | |||
| @staticmethod | |||
| def create_tenant(name: str) -> Tenant: | |||
| @@ -698,3 +748,8 @@ class RegisterService: | |||
| invitation = json.loads(data) | |||
| return invitation | |||
| def _generate_refresh_token(length: int = 64): | |||
| token = secrets.token_hex(length) | |||
| return token | |||
| @@ -91,6 +91,9 @@ MIGRATION_ENABLED=true | |||
| # The default value is 300 seconds. | |||
| FILES_ACCESS_TIMEOUT=300 | |||
| # Access token expiration time in minutes | |||
| ACCESS_TOKEN_EXPIRE_MINUTES=60 | |||
| # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. | |||
| APP_MAX_ACTIVE_REQUESTS=0 | |||
| @@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env | |||
| REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} | |||
| REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} | |||
| REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} | |||
| ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} | |||
| REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} | |||
| CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} | |||
| BROKER_USE_SSL: ${BROKER_USE_SSL:-false} | |||