| WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* | WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* | ||||
| CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* | CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* | ||||
| # Cookie configuration | |||||
| COOKIE_HTTPONLY=true | |||||
| COOKIE_SAMESITE=None | |||||
| COOKIE_SECURE=true | |||||
| # Session configuration | |||||
| SESSION_PERMANENT=true | |||||
| SESSION_USE_SIGNER=true | |||||
| ## support redis, sqlalchemy | |||||
| SESSION_TYPE=redis | |||||
| # session redis configuration | |||||
| SESSION_REDIS_HOST=localhost | |||||
| SESSION_REDIS_PORT=6379 | |||||
| SESSION_REDIS_PASSWORD=difyai123456 | |||||
| SESSION_REDIS_DB=2 | |||||
| # Vector database configuration, support: weaviate, qdrant | # Vector database configuration, support: weaviate, qdrant | ||||
| VECTOR_STORE=weaviate | VECTOR_STORE=weaviate | ||||
| # -*- coding:utf-8 -*- | # -*- coding:utf-8 -*- | ||||
| import os | import os | ||||
| from datetime import datetime, timedelta | |||||
| from werkzeug.exceptions import Forbidden | |||||
| from werkzeug.exceptions import Unauthorized | |||||
| if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': | if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': | ||||
| from gevent import monkey | from gevent import monkey | ||||
| import json | import json | ||||
| import threading | import threading | ||||
| from flask import Flask, request, Response, session | |||||
| import flask_login | |||||
| from flask import Flask, request, Response | |||||
| from flask_cors import CORS | from flask_cors import CORS | ||||
| from core.model_providers.providers import hosted | from core.model_providers.providers import hosted | ||||
| from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||||
| from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||||
| ext_database, ext_storage, ext_mail, ext_stripe | ext_database, ext_storage, ext_mail, ext_stripe | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_login import login_manager | from extensions.ext_login import login_manager | ||||
| from events import event_handlers | from events import event_handlers | ||||
| # DO NOT REMOVE ABOVE | # DO NOT REMOVE ABOVE | ||||
| import core | |||||
| from config import Config, CloudEditionConfig | from config import Config, CloudEditionConfig | ||||
| from commands import register_commands | from commands import register_commands | ||||
| from models.account import TenantAccountJoin, AccountStatus | |||||
| from models.model import Account, EndUser, App | |||||
| from services.account_service import TenantService | |||||
| from services.account_service import AccountService | |||||
| from libs.passport import PassportService | |||||
| import warnings | import warnings | ||||
| warnings.simplefilter("ignore", ResourceWarning) | warnings.simplefilter("ignore", ResourceWarning) | ||||
| ext_redis.init_app(app) | ext_redis.init_app(app) | ||||
| ext_storage.init_app(app) | ext_storage.init_app(app) | ||||
| ext_celery.init_app(app) | ext_celery.init_app(app) | ||||
| ext_session.init_app(app) | |||||
| ext_login.init_app(app) | ext_login.init_app(app) | ||||
| ext_mail.init_app(app) | ext_mail.init_app(app) | ||||
| ext_sentry.init_app(app) | ext_sentry.init_app(app) | ||||
| ext_stripe.init_app(app) | ext_stripe.init_app(app) | ||||
| def _create_tenant_for_account(account): | |||||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||||
| account.current_tenant = tenant | |||||
| return tenant | |||||
| # Flask-Login configuration | # Flask-Login configuration | ||||
| @login_manager.user_loader | |||||
| def load_user(user_id): | |||||
| """Load user based on the user_id.""" | |||||
| @login_manager.request_loader | |||||
| def load_user_from_request(request_from_flask_login): | |||||
| """Load user based on the request.""" | |||||
| if request.blueprint == 'console': | if request.blueprint == 'console': | ||||
| # Check if the user_id contains a dot, indicating the old format | # Check if the user_id contains a dot, indicating the old format | ||||
| if '.' in user_id: | |||||
| tenant_id, account_id = user_id.split('.') | |||||
| else: | |||||
| account_id = user_id | |||||
| account = db.session.query(Account).filter(Account.id == account_id).first() | |||||
| if account: | |||||
| if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | |||||
| raise Forbidden('Account is banned or closed.') | |||||
| workspace_id = session.get('workspace_id') | |||||
| if workspace_id: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id, | |||||
| TenantAccountJoin.tenant_id == workspace_id | |||||
| ).first() | |||||
| if not tenant_account_join: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id).first() | |||||
| if tenant_account_join: | |||||
| account.current_tenant_id = tenant_account_join.tenant_id | |||||
| else: | |||||
| _create_tenant_for_account(account) | |||||
| session['workspace_id'] = account.current_tenant_id | |||||
| else: | |||||
| account.current_tenant_id = workspace_id | |||||
| else: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id).first() | |||||
| if tenant_account_join: | |||||
| account.current_tenant_id = tenant_account_join.tenant_id | |||||
| else: | |||||
| _create_tenant_for_account(account) | |||||
| session['workspace_id'] = account.current_tenant_id | |||||
| current_time = datetime.utcnow() | |||||
| # update last_active_at when last_active_at is more than 10 minutes ago | |||||
| if current_time - account.last_active_at > timedelta(minutes=10): | |||||
| account.last_active_at = current_time | |||||
| db.session.commit() | |||||
| # Log in the user with the updated user_id | |||||
| flask_login.login_user(account, remember=True) | |||||
| return account | |||||
| auth_header = request.headers.get('Authorization', '') | |||||
| if ' ' not in auth_header: | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||||
| auth_scheme = auth_scheme.lower() | |||||
| if auth_scheme != 'bearer': | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| decoded = PassportService().verify(auth_token) | |||||
| user_id = decoded.get('user_id') | |||||
| return AccountService.load_user(user_id) | |||||
| else: | else: | ||||
| return None | return None | ||||
| @login_manager.unauthorized_handler | @login_manager.unauthorized_handler | ||||
| def unauthorized_handler(): | def unauthorized_handler(): | ||||
| """Handle unauthorized requests.""" | """Handle unauthorized requests.""" | ||||
| @app.after_request | @app.after_request | ||||
| def after_request(response): | def after_request(response): | ||||
| """Add Version headers to the response.""" | """Add Version headers to the response.""" | ||||
| response.set_cookie('remember_token', '', expires=0) | |||||
| response.headers.add('X-Version', app.config['CURRENT_VERSION']) | response.headers.add('X-Version', app.config['CURRENT_VERSION']) | ||||
| response.headers.add('X-Env', app.config['DEPLOY_ENV']) | response.headers.add('X-Env', app.config['DEPLOY_ENV']) | ||||
| return response | return response |
| dotenv.load_dotenv() | dotenv.load_dotenv() | ||||
| DEFAULTS = { | DEFAULTS = { | ||||
| 'COOKIE_HTTPONLY': 'True', | |||||
| 'COOKIE_SECURE': 'True', | |||||
| 'COOKIE_SAMESITE': 'None', | |||||
| 'DB_USERNAME': 'postgres', | 'DB_USERNAME': 'postgres', | ||||
| 'DB_PASSWORD': '', | 'DB_PASSWORD': '', | ||||
| 'DB_HOST': 'localhost', | 'DB_HOST': 'localhost', | ||||
| 'REDIS_PORT': '6379', | 'REDIS_PORT': '6379', | ||||
| 'REDIS_DB': '0', | 'REDIS_DB': '0', | ||||
| 'REDIS_USE_SSL': 'False', | 'REDIS_USE_SSL': 'False', | ||||
| 'SESSION_REDIS_HOST': 'localhost', | |||||
| 'SESSION_REDIS_PORT': '6379', | |||||
| 'SESSION_REDIS_DB': '2', | |||||
| 'SESSION_REDIS_USE_SSL': 'False', | |||||
| 'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize', | 'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize', | ||||
| 'OAUTH_REDIRECT_INDEX_PATH': '/', | 'OAUTH_REDIRECT_INDEX_PATH': '/', | ||||
| 'CONSOLE_WEB_URL': 'https://cloud.dify.ai', | 'CONSOLE_WEB_URL': 'https://cloud.dify.ai', | ||||
| 'STORAGE_TYPE': 'local', | 'STORAGE_TYPE': 'local', | ||||
| 'STORAGE_LOCAL_PATH': 'storage', | 'STORAGE_LOCAL_PATH': 'storage', | ||||
| 'CHECK_UPDATE_URL': 'https://updates.dify.ai', | 'CHECK_UPDATE_URL': 'https://updates.dify.ai', | ||||
| 'SESSION_TYPE': 'sqlalchemy', | |||||
| 'SESSION_PERMANENT': 'True', | |||||
| 'SESSION_USE_SIGNER': 'True', | |||||
| 'DEPLOY_ENV': 'PRODUCTION', | 'DEPLOY_ENV': 'PRODUCTION', | ||||
| 'SQLALCHEMY_POOL_SIZE': 30, | 'SQLALCHEMY_POOL_SIZE': 30, | ||||
| 'SQLALCHEMY_POOL_RECYCLE': 3600, | 'SQLALCHEMY_POOL_RECYCLE': 3600, | ||||
| # Alternatively you can set it with `SECRET_KEY` environment variable. | # Alternatively you can set it with `SECRET_KEY` environment variable. | ||||
| self.SECRET_KEY = get_env('SECRET_KEY') | self.SECRET_KEY = get_env('SECRET_KEY') | ||||
| # cookie settings | |||||
| self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY') | |||||
| self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY') | |||||
| self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE') | |||||
| self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE') | |||||
| self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE') | |||||
| self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE') | |||||
| self.PERMANENT_SESSION_LIFETIME = timedelta(days=7) | |||||
| # session settings, only support sqlalchemy, redis | |||||
| self.SESSION_TYPE = get_env('SESSION_TYPE') | |||||
| self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT') | |||||
| self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER') | |||||
| # redis settings | # redis settings | ||||
| self.REDIS_HOST = get_env('REDIS_HOST') | self.REDIS_HOST = get_env('REDIS_HOST') | ||||
| self.REDIS_PORT = get_env('REDIS_PORT') | self.REDIS_PORT = get_env('REDIS_PORT') | ||||
| self.REDIS_DB = get_env('REDIS_DB') | self.REDIS_DB = get_env('REDIS_DB') | ||||
| self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') | self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') | ||||
| # session redis settings | |||||
| self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST') | |||||
| self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT') | |||||
| self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME') | |||||
| self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD') | |||||
| self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB') | |||||
| self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL') | |||||
| # storage settings | # storage settings | ||||
| self.STORAGE_TYPE = get_env('STORAGE_TYPE') | self.STORAGE_TYPE = get_env('STORAGE_TYPE') | ||||
| self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') | self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') |
| import services | import services | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.error import AccountNotLinkTenantError | |||||
| from controllers.console.setup import setup_required | from controllers.console.setup import setup_required | ||||
| from libs.helper import email | from libs.helper import email | ||||
| from libs.password import valid_password | from libs.password import valid_password | ||||
| except Exception: | except Exception: | ||||
| pass | pass | ||||
| flask_login.login_user(account, remember=args['remember_me']) | |||||
| AccountService.update_last_login(account, request) | AccountService.update_last_login(account, request) | ||||
| # todo: return the user info | # todo: return the user info | ||||
| token = AccountService.get_account_jwt_token(account) | |||||
| return {'result': 'success'} | |||||
| return {'result': 'success', 'data': token} | |||||
| class LogoutApi(Resource): | class LogoutApi(Resource): |
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Optional | from typing import Optional | ||||
| import flask_login | |||||
| import requests | import requests | ||||
| from flask import request, redirect, current_app, session | |||||
| from flask import request, redirect, current_app | |||||
| from flask_restful import Resource | from flask_restful import Resource | ||||
| from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth | from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth | ||||
| account.initialized_at = datetime.utcnow() | account.initialized_at = datetime.utcnow() | ||||
| db.session.commit() | db.session.commit() | ||||
| # login user | |||||
| session.clear() | |||||
| flask_login.login_user(account, remember=True) | |||||
| AccountService.update_last_login(account, request) | AccountService.update_last_login(account, request) | ||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success') | |||||
| token = AccountService.get_account_jwt_token(account) | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}') | |||||
| def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: |
| # -*- coding:utf-8 -*- | # -*- coding:utf-8 -*- | ||||
| from functools import wraps | from functools import wraps | ||||
| import flask_login | |||||
| from flask import request, current_app | from flask import request, current_app | ||||
| from flask_restful import Resource, reqparse | from flask_restful import Resource, reqparse | ||||
| ) | ) | ||||
| setup() | setup() | ||||
| # Login | |||||
| flask_login.login_user(account) | |||||
| AccountService.update_last_login(account, request) | AccountService.update_last_login(account, request) | ||||
| return {'result': 'success'}, 201 | return {'result': 'success'}, 201 |
| import os | import os | ||||
| from functools import wraps | from functools import wraps | ||||
| import flask_login | |||||
| from flask import current_app | from flask import current_app | ||||
| from flask import g | from flask import g | ||||
| from flask import has_request_context | from flask import has_request_context | ||||
| from flask import request | |||||
| from flask import request, session | |||||
| from flask_login import user_logged_in | from flask_login import user_logged_in | ||||
| from flask_login.config import EXEMPT_METHODS | from flask_login.config import EXEMPT_METHODS | ||||
| from werkzeug.exceptions import Unauthorized | from werkzeug.exceptions import Unauthorized |
| import redis | |||||
| from redis.connection import SSLConnection, Connection | |||||
| from flask import request | |||||
| from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface | |||||
| from flask_session.sessions import total_seconds | |||||
| from itsdangerous import want_bytes | |||||
| from extensions.ext_database import db | |||||
| sess = Session() | |||||
| def init_app(app): | |||||
| sqlalchemy_session_interface = CustomSqlAlchemySessionInterface( | |||||
| app, | |||||
| db, | |||||
| app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'), | |||||
| app.config.get('SESSION_KEY_PREFIX', 'session:'), | |||||
| app.config.get('SESSION_USE_SIGNER', False), | |||||
| app.config.get('SESSION_PERMANENT', True) | |||||
| ) | |||||
| session_type = app.config.get('SESSION_TYPE') | |||||
| if session_type == 'sqlalchemy': | |||||
| app.session_interface = sqlalchemy_session_interface | |||||
| elif session_type == 'redis': | |||||
| connection_class = Connection | |||||
| if app.config.get('SESSION_REDIS_USE_SSL', False): | |||||
| connection_class = SSLConnection | |||||
| sess_redis_client = redis.Redis() | |||||
| sess_redis_client.connection_pool = redis.ConnectionPool(**{ | |||||
| 'host': app.config.get('SESSION_REDIS_HOST', 'localhost'), | |||||
| 'port': app.config.get('SESSION_REDIS_PORT', 6379), | |||||
| 'username': app.config.get('SESSION_REDIS_USERNAME', None), | |||||
| 'password': app.config.get('SESSION_REDIS_PASSWORD', None), | |||||
| 'db': app.config.get('SESSION_REDIS_DB', 2), | |||||
| 'encoding': 'utf-8', | |||||
| 'encoding_errors': 'strict', | |||||
| 'decode_responses': False | |||||
| }, connection_class=connection_class) | |||||
| app.extensions['session_redis'] = sess_redis_client | |||||
| app.session_interface = CustomRedisSessionInterface( | |||||
| sess_redis_client, | |||||
| app.config.get('SESSION_KEY_PREFIX', 'session:'), | |||||
| app.config.get('SESSION_USE_SIGNER', False), | |||||
| app.config.get('SESSION_PERMANENT', True) | |||||
| ) | |||||
| class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface): | |||||
| def __init__( | |||||
| self, | |||||
| app, | |||||
| db, | |||||
| table, | |||||
| key_prefix, | |||||
| use_signer=False, | |||||
| permanent=True, | |||||
| sequence=None, | |||||
| autodelete=False, | |||||
| ): | |||||
| if db is None: | |||||
| from flask_sqlalchemy import SQLAlchemy | |||||
| db = SQLAlchemy(app) | |||||
| self.db = db | |||||
| self.key_prefix = key_prefix | |||||
| self.use_signer = use_signer | |||||
| self.permanent = permanent | |||||
| self.autodelete = autodelete | |||||
| self.sequence = sequence | |||||
| self.has_same_site_capability = hasattr(self, "get_cookie_samesite") | |||||
| class Session(self.db.Model): | |||||
| __tablename__ = table | |||||
| if sequence: | |||||
| id = self.db.Column( # noqa: A003, VNE003, A001 | |||||
| self.db.Integer, self.db.Sequence(sequence), primary_key=True | |||||
| ) | |||||
| else: | |||||
| id = self.db.Column( # noqa: A003, VNE003, A001 | |||||
| self.db.Integer, primary_key=True | |||||
| ) | |||||
| session_id = self.db.Column(self.db.String(255), unique=True) | |||||
| data = self.db.Column(self.db.LargeBinary) | |||||
| expiry = self.db.Column(self.db.DateTime) | |||||
| def __init__(self, session_id, data, expiry): | |||||
| self.session_id = session_id | |||||
| self.data = data | |||||
| self.expiry = expiry | |||||
| def __repr__(self): | |||||
| return f"<Session data {self.data}>" | |||||
| self.sql_session_model = Session | |||||
| def save_session(self, *args, **kwargs): | |||||
| if request.blueprint == 'service_api': | |||||
| return | |||||
| elif request.method == 'OPTIONS': | |||||
| return | |||||
| elif request.endpoint and request.endpoint == 'health': | |||||
| return | |||||
| return super().save_session(*args, **kwargs) | |||||
| class CustomRedisSessionInterface(RedisSessionInterface): | |||||
| def save_session(self, app, session, response): | |||||
| if request.blueprint == 'service_api': | |||||
| return | |||||
| elif request.method == 'OPTIONS': | |||||
| return | |||||
| elif request.endpoint and request.endpoint == 'health': | |||||
| return | |||||
| if not self.should_set_cookie(app, session): | |||||
| return | |||||
| domain = self.get_cookie_domain(app) | |||||
| path = self.get_cookie_path(app) | |||||
| if not session: | |||||
| if session.modified: | |||||
| self.redis.delete(self.key_prefix + session.sid) | |||||
| response.delete_cookie( | |||||
| app.config["SESSION_COOKIE_NAME"], domain=domain, path=path | |||||
| ) | |||||
| return | |||||
| # Modification case. There are upsides and downsides to | |||||
| # emitting a set-cookie header each request. The behavior | |||||
| # is controlled by the :meth:`should_set_cookie` method | |||||
| # which performs a quick check to figure out if the cookie | |||||
| # should be set or not. This is controlled by the | |||||
| # SESSION_REFRESH_EACH_REQUEST config flag as well as | |||||
| # the permanent flag on the session itself. | |||||
| # if not self.should_set_cookie(app, session): | |||||
| # return | |||||
| conditional_cookie_kwargs = {} | |||||
| httponly = self.get_cookie_httponly(app) | |||||
| secure = self.get_cookie_secure(app) | |||||
| if self.has_same_site_capability: | |||||
| conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app) | |||||
| expires = self.get_expiration_time(app, session) | |||||
| if session.permanent: | |||||
| value = self.serializer.dumps(dict(session)) | |||||
| if value is not None: | |||||
| self.redis.setex( | |||||
| name=self.key_prefix + session.sid, | |||||
| value=value, | |||||
| time=total_seconds(app.permanent_session_lifetime), | |||||
| ) | |||||
| if self.use_signer: | |||||
| session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8") | |||||
| else: | |||||
| session_id = session.sid | |||||
| response.set_cookie( | |||||
| app.config["SESSION_COOKIE_NAME"], | |||||
| session_id, | |||||
| expires=expires, | |||||
| httponly=httponly, | |||||
| domain=domain, | |||||
| path=path, | |||||
| secure=secure, | |||||
| **conditional_cookie_kwargs, | |||||
| ) |
| import logging | import logging | ||||
| import secrets | import secrets | ||||
| import uuid | import uuid | ||||
| from datetime import datetime | |||||
| from datetime import datetime, timedelta | |||||
| from hashlib import sha256 | from hashlib import sha256 | ||||
| from typing import Optional | from typing import Optional | ||||
| from flask import session | |||||
| from werkzeug.exceptions import Forbidden, Unauthorized | |||||
| from flask import session, current_app | |||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from events.tenant_event import tenant_was_created | from events.tenant_event import tenant_was_created | ||||
| from libs.helper import get_remote_ip | from libs.helper import get_remote_ip | ||||
| from libs.password import compare_password, hash_password | from libs.password import compare_password, hash_password | ||||
| from libs.rsa import generate_key_pair | from libs.rsa import generate_key_pair | ||||
| from libs.passport import PassportService | |||||
| from models.account import * | from models.account import * | ||||
| from tasks.mail_invite_member_task import send_invite_member_mail_task | from tasks.mail_invite_member_task import send_invite_member_mail_task | ||||
| def _create_tenant_for_account(account): | |||||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||||
| account.current_tenant = tenant | |||||
| return tenant | |||||
| class AccountService: | class AccountService: | ||||
| @staticmethod | @staticmethod | ||||
| def load_user(account_id: int) -> Account: | |||||
| def load_user(user_id: str) -> Account: | |||||
| # todo: used by flask_login | # todo: used by flask_login | ||||
| pass | |||||
| if '.' in user_id: | |||||
| tenant_id, account_id = user_id.split('.') | |||||
| else: | |||||
| account_id = user_id | |||||
| account = db.session.query(Account).filter(Account.id == account_id).first() | |||||
| if account: | |||||
| if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | |||||
| raise Forbidden('Account is banned or closed.') | |||||
| workspace_id = session.get('workspace_id') | |||||
| if workspace_id: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id, | |||||
| TenantAccountJoin.tenant_id == workspace_id | |||||
| ).first() | |||||
| if not tenant_account_join: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id).first() | |||||
| if tenant_account_join: | |||||
| account.current_tenant_id = tenant_account_join.tenant_id | |||||
| else: | |||||
| _create_tenant_for_account(account) | |||||
| session['workspace_id'] = account.current_tenant_id | |||||
| else: | |||||
| account.current_tenant_id = workspace_id | |||||
| else: | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||||
| TenantAccountJoin.account_id == account.id).first() | |||||
| if tenant_account_join: | |||||
| account.current_tenant_id = tenant_account_join.tenant_id | |||||
| else: | |||||
| _create_tenant_for_account(account) | |||||
| session['workspace_id'] = account.current_tenant_id | |||||
| current_time = datetime.utcnow() | |||||
| # update last_active_at when last_active_at is more than 10 minutes ago | |||||
| if current_time - account.last_active_at > timedelta(minutes=10): | |||||
| account.last_active_at = current_time | |||||
| db.session.commit() | |||||
| return account | |||||
| @staticmethod | |||||
| def get_account_jwt_token(account): | |||||
| payload = { | |||||
| "user_id": account.id, | |||||
| "exp": datetime.utcnow() + timedelta(days=30), | |||||
| "iss": current_app.config['EDITION'], | |||||
| "sub": 'Console API Passport', | |||||
| } | |||||
| token = PassportService().issue(payload) | |||||
| return token | |||||
| @staticmethod | @staticmethod | ||||
| def authenticate(email: str, password: str) -> Account: | def authenticate(email: str, password: str) -> Account: |
| REDIS_USE_SSL: 'false' | REDIS_USE_SSL: 'false' | ||||
| # use redis db 0 for redis cache | # use redis db 0 for redis cache | ||||
| REDIS_DB: 0 | REDIS_DB: 0 | ||||
| # The configurations of session, Supported values are `sqlalchemy`. `redis` | |||||
| SESSION_TYPE: redis | |||||
| SESSION_REDIS_HOST: redis | |||||
| SESSION_REDIS_PORT: 6379 | |||||
| SESSION_REDIS_USERNAME: '' | |||||
| SESSION_REDIS_PASSWORD: difyai123456 | |||||
| SESSION_REDIS_USE_SSL: 'false' | |||||
| # use redis db 2 for session store | |||||
| SESSION_REDIS_DB: 2 | |||||
| # The configurations of celery broker. | # The configurations of celery broker. | ||||
| # Use redis as the broker, and redis db 1 for celery broker. | # Use redis as the broker, and redis db 1 for celery broker. | ||||
| CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 | CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 | ||||
| # If you want to enable cross-origin support, | # If you want to enable cross-origin support, | ||||
| # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`. | # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`. | ||||
| # | # | ||||
| # For **production** purposes, please set `SameSite=Lax, Secure=true, HttpOnly=true`. | |||||
| COOKIE_HTTPONLY: 'true' | |||||
| COOKIE_SAMESITE: 'Lax' | |||||
| COOKIE_SECURE: 'false' | |||||
| # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` | # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` | ||||
| STORAGE_TYPE: local | STORAGE_TYPE: local | ||||
| # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. | # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. |
| 'use client' | 'use client' | ||||
| import { SWRConfig } from 'swr' | import { SWRConfig } from 'swr' | ||||
| import { useEffect, useState } from 'react' | |||||
| import type { ReactNode } from 'react' | import type { ReactNode } from 'react' | ||||
| import { useRouter, useSearchParams } from 'next/navigation' | |||||
| type SwrInitorProps = { | type SwrInitorProps = { | ||||
| children: ReactNode | children: ReactNode | ||||
| const SwrInitor = ({ | const SwrInitor = ({ | ||||
| children, | children, | ||||
| }: SwrInitorProps) => { | }: SwrInitorProps) => { | ||||
| return ( | |||||
| <SWRConfig value={{ | |||||
| shouldRetryOnError: false, | |||||
| }}> | |||||
| {children} | |||||
| </SWRConfig> | |||||
| ) | |||||
| const router = useRouter() | |||||
| const searchParams = useSearchParams() | |||||
| const consoleToken = searchParams.get('console_token') | |||||
| const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') | |||||
| const [init, setInit] = useState(false) | |||||
| useEffect(() => { | |||||
| if (!(consoleToken || consoleTokenFromLocalStorage)) | |||||
| router.replace('/signin') | |||||
| if (consoleToken) { | |||||
| localStorage?.setItem('console_token', consoleToken!) | |||||
| router.replace('/apps', { forceOptimisticNavigation: false }) | |||||
| } | |||||
| setInit(true) | |||||
| }, []) | |||||
| return init | |||||
| ? ( | |||||
| <SWRConfig value={{ | |||||
| shouldRetryOnError: false, | |||||
| }}> | |||||
| {children} | |||||
| </SWRConfig> | |||||
| ) | |||||
| : null | |||||
| } | } | ||||
| export default SwrInitor | export default SwrInitor |
| const Header = () => { | const Header = () => { | ||||
| const { locale, setLocaleOnClient } = useContext(I18n) | const { locale, setLocaleOnClient } = useContext(I18n) | ||||
| if (localStorage?.getItem('console_token')) | |||||
| localStorage.removeItem('console_token') | |||||
| return <div className='flex items-center justify-between p-6 w-full'> | return <div className='flex items-center justify-between p-6 w-full'> | ||||
| <div className={style.logo}></div> | <div className={style.logo}></div> | ||||
| <Select | <Select |
| } | } | ||||
| try { | try { | ||||
| setIsLoading(true) | setIsLoading(true) | ||||
| await login({ | |||||
| const res = await login({ | |||||
| url: '/login', | url: '/login', | ||||
| body: { | body: { | ||||
| email, | email, | ||||
| remember_me: true, | remember_me: true, | ||||
| }, | }, | ||||
| }) | }) | ||||
| router.push('/apps') | |||||
| localStorage.setItem('console_token', res.data) | |||||
| router.replace('/apps') | |||||
| } | } | ||||
| finally { | finally { | ||||
| setIsLoading(false) | setIsLoading(false) |
| } | } | ||||
| options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`) | options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`) | ||||
| } | } | ||||
| else { | |||||
| const accessToken = localStorage.getItem('console_token') || '' | |||||
| options.headers.set('Authorization', `Bearer ${accessToken}`) | |||||
| } | |||||
| if (deleteContentType) { | if (deleteContentType) { | ||||
| options.headers.delete('Content-Type') | options.headers.delete('Content-Type') | ||||
| const defaultOptions = { | const defaultOptions = { | ||||
| method: 'POST', | method: 'POST', | ||||
| url: `${API_PREFIX}/files/upload`, | url: `${API_PREFIX}/files/upload`, | ||||
| headers: {}, | |||||
| headers: { | |||||
| Authorization: `Bearer ${localStorage.getItem('console_token') || ''}`, | |||||
| }, | |||||
| data: {}, | data: {}, | ||||
| } | } | ||||
| options = { | options = { |
| } from '@/models/app' | } from '@/models/app' | ||||
| import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations' | import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations' | ||||
| export const login: Fetcher<CommonResponse, { url: string; body: Record<string, any> }> = ({ url, body }) => { | |||||
| return post<CommonResponse>(url, { body }) | |||||
| export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => { | |||||
| return post(url, { body }) as Promise<CommonResponse & { data: string }> | |||||
| } | } | ||||
| export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => { | export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => { |