| @@ -1,5 +1,7 @@ | |||
| import os | |||
| from configs import dify_config | |||
| if os.environ.get("DEBUG", "false").lower() != "true": | |||
| from gevent import monkey | |||
| @@ -36,17 +38,11 @@ if hasattr(time, "tzset"): | |||
| time.tzset() | |||
| # ------------- | |||
| # Configuration | |||
| # ------------- | |||
| config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first | |||
| # create app | |||
| app = create_app() | |||
| celery = app.extensions["celery"] | |||
| if app.config.get("TESTING"): | |||
| if dify_config.TESTING: | |||
| print("App is running in TESTING mode") | |||
| @@ -54,15 +50,15 @@ if app.config.get("TESTING"): | |||
| def after_request(response): | |||
| """Add Version headers to the response.""" | |||
| response.set_cookie("remember_token", "", expires=0) | |||
| response.headers.add("X-Version", app.config["CURRENT_VERSION"]) | |||
| response.headers.add("X-Env", app.config["DEPLOY_ENV"]) | |||
| response.headers.add("X-Version", dify_config.CURRENT_VERSION) | |||
| response.headers.add("X-Env", dify_config.DEPLOY_ENV) | |||
| return response | |||
| @app.route("/health") | |||
| def health(): | |||
| return Response( | |||
| json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), | |||
| json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), | |||
| status=200, | |||
| content_type="application/json", | |||
| ) | |||
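The hunks above swap dictionary-style app.config[...] lookups for attribute access on the shared dify_config object. A minimal sketch of that pattern with pydantic-settings follows; the class name, fields, and defaults are illustrative stand-ins, not the real Dify configuration.

# Illustrative sketch only: a pydantic-settings object exposes environment-driven
# values as typed attributes, assuming pydantic-settings v2 is installed.
from pydantic import Field
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    CURRENT_VERSION: str = Field(description="Version reported in the X-Version header", default="0.0.0")
    DEPLOY_ENV: str = Field(description="Deployment environment label", default="PRODUCTION")
    TESTING: bool = Field(description="Flag for test runs", default=False)


example_config = ExampleConfig()  # field values are read from the process environment

# Attribute access replaces app.config["CURRENT_VERSION"] / app.config.get("DEPLOY_ENV"):
print(example_config.CURRENT_VERSION, example_config.DEPLOY_ENV, example_config.TESTING)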
| @@ -68,7 +68,7 @@ def create_flask_app_with_configs() -> Flask: | |||
| def create_app() -> Flask: | |||
| app = create_flask_app_with_configs() | |||
| app.secret_key = app.config["SECRET_KEY"] | |||
| app.secret_key = dify_config.SECRET_KEY | |||
| initialize_extensions(app) | |||
| register_blueprints(app) | |||
| register_commands(app) | |||
| @@ -150,7 +150,7 @@ def register_blueprints(app): | |||
| CORS( | |||
| web_bp, | |||
| resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, | |||
| resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| @@ -161,7 +161,7 @@ def register_blueprints(app): | |||
| CORS( | |||
| console_app_bp, | |||
| resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, | |||
| resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| @@ -32,6 +32,21 @@ class SecurityConfig(BaseSettings): | |||
| default=5, | |||
| ) | |||
| LOGIN_DISABLED: bool = Field( | |||
| description="Whether to disable login checks", | |||
| default=False, | |||
| ) | |||
ADMIN_API_KEY_ENABLE: bool = Field(
description="Whether to enable the admin API key for authentication",
default=False,
)
ADMIN_API_KEY: Optional[str] = Field(
description="Admin API key for authentication",
default=None,
)
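The two new fields above are populated from environment variables of the same name by pydantic-settings. A hedged sketch of how they behave, using a trimmed stand-in class rather than the full SecurityConfig; the key value below is made up.

# Sketch assuming pydantic-settings v2; only the two new fields are reproduced.
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class SecuritySketch(BaseSettings):
    ADMIN_API_KEY_ENABLE: bool = Field(default=False)
    ADMIN_API_KEY: Optional[str] = Field(default=None)


os.environ["ADMIN_API_KEY_ENABLE"] = "true"  # string values are coerced to bool
os.environ["ADMIN_API_KEY"] = "example-admin-key"

config = SecuritySketch()
assert config.ADMIN_API_KEY_ENABLE is True
assert config.ADMIN_API_KEY == "example-admin-key"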
| class AppExecutionConfig(BaseSettings): | |||
| """ | |||
| @@ -1,10 +1,10 @@ | |||
| import os | |||
| from functools import wraps | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| from controllers.console.wraps import only_edition_cloud | |||
| @@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp | |||
| def admin_required(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not os.getenv("ADMIN_API_KEY"): | |||
| if not dify_config.ADMIN_API_KEY: | |||
| raise Unauthorized("API key is invalid.") | |||
| auth_header = request.headers.get("Authorization") | |||
| @@ -31,7 +31,7 @@ def admin_required(view): | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| if os.getenv("ADMIN_API_KEY") != auth_token: | |||
| if dify_config.ADMIN_API_KEY != auth_token: | |||
| raise Unauthorized("API key is invalid.") | |||
| return view(*args, **kwargs) | |||
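With admin_required reading dify_config.ADMIN_API_KEY, callers still authenticate with a bearer token. A standalone sketch of the header check the decorator performs; the helper name is invented and the logic is paraphrased from the hunk above.

# Hypothetical helper mirroring the decorator's bearer-token validation.
from werkzeug.exceptions import Unauthorized


def check_admin_token(auth_header, configured_key):
    if not configured_key:
        raise Unauthorized("API key is invalid.")
    if not auth_header or " " not in auth_header:
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    auth_scheme, auth_token = auth_header.split(None, 1)
    if auth_scheme.lower() != "bearer":
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    if auth_token != configured_key:
        raise Unauthorized("API key is invalid.")


# Usage inside a request: check_admin_token(request.headers.get("Authorization"), dify_config.ADMIN_API_KEY)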
| @@ -1,8 +1,9 @@ | |||
| from typing import Optional | |||
| from flask import Config, Flask | |||
| from flask import Flask | |||
| from pydantic import BaseModel | |||
| from configs import dify_config | |||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.provider import ProviderQuotaType | |||
| @@ -44,32 +45,30 @@ class HostingConfiguration: | |||
| moderation_config: HostedModerationConfig = None | |||
| def init_app(self, app: Flask) -> None: | |||
| config = app.config | |||
| if config.get("EDITION") != "CLOUD": | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| self.provider_map["azure_openai"] = self.init_azure_openai(config) | |||
| self.provider_map["openai"] = self.init_openai(config) | |||
| self.provider_map["anthropic"] = self.init_anthropic(config) | |||
| self.provider_map["minimax"] = self.init_minimax(config) | |||
| self.provider_map["spark"] = self.init_spark(config) | |||
| self.provider_map["zhipuai"] = self.init_zhipuai(config) | |||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||
| self.provider_map["openai"] = self.init_openai() | |||
| self.provider_map["anthropic"] = self.init_anthropic() | |||
| self.provider_map["minimax"] = self.init_minimax() | |||
| self.provider_map["spark"] = self.init_spark() | |||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||
| self.moderation_config = self.init_moderation_config(config) | |||
| self.moderation_config = self.init_moderation_config() | |||
| @staticmethod | |||
| def init_azure_openai(app_config: Config) -> HostingProvider: | |||
| def init_azure_openai() -> HostingProvider: | |||
| quota_unit = QuotaUnit.TIMES | |||
| if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): | |||
| if dify_config.HOSTED_AZURE_OPENAI_ENABLED: | |||
| credentials = { | |||
| "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), | |||
| "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), | |||
| "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY, | |||
| "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE, | |||
| "base_model_name": "gpt-35-turbo", | |||
| } | |||
| quotas = [] | |||
| hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) | |||
| hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT | |||
| trial_quota = TrialHostingQuota( | |||
| quota_limit=hosted_quota_limit, | |||
| restrict_models=[ | |||
| @@ -122,31 +121,31 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_openai(self, app_config: Config) -> HostingProvider: | |||
| def init_openai(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.CREDITS | |||
| quotas = [] | |||
| if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): | |||
| hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) | |||
| trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") | |||
| if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: | |||
| hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT | |||
| trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") | |||
| trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) | |||
| quotas.append(trial_quota) | |||
| if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): | |||
| paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") | |||
| if dify_config.HOSTED_OPENAI_PAID_ENABLED: | |||
| paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS") | |||
| paid_quota = PaidHostingQuota(restrict_models=paid_models) | |||
| quotas.append(paid_quota) | |||
| if len(quotas) > 0: | |||
| credentials = { | |||
| "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"), | |||
| "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY, | |||
| } | |||
| if app_config.get("HOSTED_OPENAI_API_BASE"): | |||
| credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE") | |||
| if dify_config.HOSTED_OPENAI_API_BASE: | |||
| credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE | |||
| if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): | |||
| credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") | |||
| if dify_config.HOSTED_OPENAI_API_ORGANIZATION: | |||
| credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION | |||
| return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | |||
| @@ -156,26 +155,26 @@ class HostingConfiguration: | |||
| ) | |||
| @staticmethod | |||
| def init_anthropic(app_config: Config) -> HostingProvider: | |||
| def init_anthropic() -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| quotas = [] | |||
| if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): | |||
| hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) | |||
| if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: | |||
| hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT | |||
| trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) | |||
| quotas.append(trial_quota) | |||
| if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): | |||
| if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: | |||
| paid_quota = PaidHostingQuota() | |||
| quotas.append(paid_quota) | |||
| if len(quotas) > 0: | |||
| credentials = { | |||
| "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"), | |||
| "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY, | |||
| } | |||
| if app_config.get("HOSTED_ANTHROPIC_API_BASE"): | |||
| credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") | |||
| if dify_config.HOSTED_ANTHROPIC_API_BASE: | |||
| credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE | |||
| return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | |||
| @@ -185,9 +184,9 @@ class HostingConfiguration: | |||
| ) | |||
| @staticmethod | |||
| def init_minimax(app_config: Config) -> HostingProvider: | |||
| def init_minimax() -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_MINIMAX_ENABLED"): | |||
| if dify_config.HOSTED_MINIMAX_ENABLED: | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| @@ -203,9 +202,9 @@ class HostingConfiguration: | |||
| ) | |||
| @staticmethod | |||
| def init_spark(app_config: Config) -> HostingProvider: | |||
| def init_spark() -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_SPARK_ENABLED"): | |||
| if dify_config.HOSTED_SPARK_ENABLED: | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| @@ -221,9 +220,9 @@ class HostingConfiguration: | |||
| ) | |||
| @staticmethod | |||
| def init_zhipuai(app_config: Config) -> HostingProvider: | |||
| def init_zhipuai() -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_ZHIPUAI_ENABLED"): | |||
| if dify_config.HOSTED_ZHIPUAI_ENABLED: | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| @@ -239,17 +238,15 @@ class HostingConfiguration: | |||
| ) | |||
| @staticmethod | |||
| def init_moderation_config(app_config: Config) -> HostedModerationConfig: | |||
| if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): | |||
| return HostedModerationConfig( | |||
| enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") | |||
| ) | |||
| def init_moderation_config() -> HostedModerationConfig: | |||
| if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: | |||
| return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) | |||
| return HostedModerationConfig(enabled=False) | |||
| @staticmethod | |||
| def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: | |||
| models_str = app_config.get(env_var) | |||
| def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]: | |||
| models_str = dify_config.model_dump().get(env_var) | |||
| models_list = models_str.split(",") if models_str else [] | |||
| return [ | |||
| RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) | |||
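parse_restrict_models_from_env now resolves the variable name through dify_config.model_dump() instead of app.config.get(). A small sketch of that lookup against a pydantic-settings object; the config class and field below are invented for illustration.

# Sketch of a by-name settings lookup; HostingSketch and its field are illustrative.
from pydantic import Field
from pydantic_settings import BaseSettings


class HostingSketch(BaseSettings):
    HOSTED_OPENAI_TRIAL_MODELS: str = Field(default="gpt-3.5-turbo,gpt-4o-mini")


config = HostingSketch()

# model_dump() returns a plain dict keyed by field name, so .get() mirrors app.config.get().
models_str = config.model_dump().get("HOSTED_OPENAI_TRIAL_MODELS")
models = [name.strip() for name in models_str.split(",") if name.strip()]
print(models)  # ['gpt-3.5-turbo', 'gpt-4o-mini']

# getattr(config, "HOSTED_OPENAI_TRIAL_MODELS", None) reads the same value without
# serializing the whole model; both approaches work here.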
| @@ -428,14 +428,13 @@ class QdrantVectorFactory(AbstractVectorFactory): | |||
| if not dataset.index_struct_dict: | |||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) | |||
| config = current_app.config | |||
| return QdrantVector( | |||
| collection_name=collection_name, | |||
| group_id=dataset.id, | |||
| config=QdrantConfig( | |||
| endpoint=dify_config.QDRANT_URL, | |||
| api_key=dify_config.QDRANT_API_KEY, | |||
| root_path=config.root_path, | |||
| root_path=current_app.config.root_path, | |||
| timeout=dify_config.QDRANT_CLIENT_TIMEOUT, | |||
| grpc_port=dify_config.QDRANT_GRPC_PORT, | |||
| prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, | |||
| @@ -3,6 +3,8 @@ from datetime import timedelta | |||
| from celery import Celery, Task | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| def init_app(app: Flask) -> Celery: | |||
| class FlaskTask(Task): | |||
| @@ -12,19 +14,19 @@ def init_app(app: Flask) -> Celery: | |||
| broker_transport_options = {} | |||
| if app.config.get("CELERY_USE_SENTINEL"): | |||
| if dify_config.CELERY_USE_SENTINEL: | |||
| broker_transport_options = { | |||
| "master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"), | |||
| "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, | |||
| "sentinel_kwargs": { | |||
| "socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1), | |||
| "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, | |||
| }, | |||
| } | |||
| celery_app = Celery( | |||
| app.name, | |||
| task_cls=FlaskTask, | |||
| broker=app.config.get("CELERY_BROKER_URL"), | |||
| backend=app.config.get("CELERY_BACKEND"), | |||
| broker=dify_config.CELERY_BROKER_URL, | |||
| backend=dify_config.CELERY_BACKEND, | |||
| task_ignore_result=True, | |||
| ) | |||
| @@ -37,12 +39,12 @@ def init_app(app: Flask) -> Celery: | |||
| } | |||
| celery_app.conf.update( | |||
| result_backend=app.config.get("CELERY_RESULT_BACKEND"), | |||
| result_backend=dify_config.CELERY_RESULT_BACKEND, | |||
| broker_transport_options=broker_transport_options, | |||
| broker_connection_retry_on_startup=True, | |||
| ) | |||
| if app.config.get("BROKER_USE_SSL"): | |||
| if dify_config.BROKER_USE_SSL: | |||
| celery_app.conf.update( | |||
| broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration | |||
| ) | |||
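The Celery wiring above now takes its broker, backend, and Sentinel options from dify_config. A minimal sketch of the same shape, with a placeholder broker URL and master name rather than real deployment values:

# Minimal sketch; the broker URL and master name are placeholders.
from celery import Celery


def build_celery(use_sentinel, broker_url, master_name=None):
    celery_app = Celery("example", broker=broker_url, task_ignore_result=True)
    transport_options = {}
    if use_sentinel:
        # With a sentinel:// broker URL, kombu resolves the master through Redis Sentinel.
        transport_options = {
            "master_name": master_name,
            "sentinel_kwargs": {"socket_timeout": 0.1},
        }
    celery_app.conf.update(
        broker_transport_options=transport_options,
        broker_connection_retry_on_startup=True,
    )
    return celery_app


celery_app = build_celery(use_sentinel=False, broker_url="redis://localhost:6379/1")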
| @@ -54,7 +56,7 @@ def init_app(app: Flask) -> Celery: | |||
| "schedule.clean_embedding_cache_task", | |||
| "schedule.clean_unused_datasets_task", | |||
| ] | |||
| day = app.config.get("CELERY_BEAT_SCHEDULER_TIME") | |||
| day = dify_config.CELERY_BEAT_SCHEDULER_TIME | |||
| beat_schedule = { | |||
| "clean_embedding_cache_task": { | |||
| "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", | |||
| @@ -1,8 +1,10 @@ | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| def init_app(app: Flask): | |||
| if app.config.get("API_COMPRESSION_ENABLED"): | |||
| if dify_config.API_COMPRESSION_ENABLED: | |||
| from flask_compress import Compress | |||
| app.config["COMPRESS_MIMETYPES"] = [ | |||
| @@ -5,10 +5,12 @@ from logging.handlers import RotatingFileHandler | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| def init_app(app: Flask): | |||
| log_handlers = None | |||
| log_file = app.config.get("LOG_FILE") | |||
| log_file = dify_config.LOG_FILE | |||
| if log_file: | |||
| log_dir = os.path.dirname(log_file) | |||
| os.makedirs(log_dir, exist_ok=True) | |||
| @@ -22,13 +24,13 @@ def init_app(app: Flask): | |||
| ] | |||
| logging.basicConfig( | |||
| level=app.config.get("LOG_LEVEL"), | |||
| format=app.config.get("LOG_FORMAT"), | |||
| datefmt=app.config.get("LOG_DATEFORMAT"), | |||
| level=dify_config.LOG_LEVEL, | |||
| format=dify_config.LOG_FORMAT, | |||
| datefmt=dify_config.LOG_DATEFORMAT, | |||
| handlers=log_handlers, | |||
| force=True, | |||
| ) | |||
| log_tz = app.config.get("LOG_TZ") | |||
| log_tz = dify_config.LOG_TZ | |||
| if log_tz: | |||
| from datetime import datetime | |||
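When LOG_TZ is set, the handler formatters convert record timestamps into that zone. The diff only shows the start of that branch; the sketch below reconstructs the usual converter-swap pattern under the assumption of an IANA zone name, and the exact wiring in ext_logging may differ.

# Sketch of timezone-aware log timestamps; "UTC" stands in for dify_config.LOG_TZ.
import logging
from datetime import datetime
from zoneinfo import ZoneInfo


def make_converter(tz_name):
    tz = ZoneInfo(tz_name)

    def converter(seconds):
        # Convert the record's creation time (epoch seconds) into the configured zone.
        return datetime.fromtimestamp(seconds, tz=tz).timetuple()

    return converter


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", force=True)
for handler in logging.root.handlers:
    if handler.formatter:
        # Assign on the formatter instance so the function is not rebound as a method.
        handler.formatter.converter = make_converter("UTC")

logging.getLogger(__name__).info("timestamp rendered in the configured timezone")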
| @@ -4,6 +4,8 @@ from typing import Optional | |||
| import resend | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| class Mail: | |||
| def __init__(self): | |||
| @@ -14,41 +16,44 @@ class Mail: | |||
| return self._client is not None | |||
| def init_app(self, app: Flask): | |||
| if app.config.get("MAIL_TYPE"): | |||
| if app.config.get("MAIL_DEFAULT_SEND_FROM"): | |||
| self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") | |||
| mail_type = dify_config.MAIL_TYPE | |||
| if not mail_type: | |||
| logging.warning("MAIL_TYPE is not set") | |||
| return | |||
| if app.config.get("MAIL_TYPE") == "resend": | |||
| api_key = app.config.get("RESEND_API_KEY") | |||
| if dify_config.MAIL_DEFAULT_SEND_FROM: | |||
| self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM | |||
| match mail_type: | |||
| case "resend": | |||
| api_key = dify_config.RESEND_API_KEY | |||
| if not api_key: | |||
| raise ValueError("RESEND_API_KEY is not set") | |||
| api_url = app.config.get("RESEND_API_URL") | |||
| api_url = dify_config.RESEND_API_URL | |||
| if api_url: | |||
| resend.api_url = api_url | |||
| resend.api_key = api_key | |||
| self._client = resend.Emails | |||
| elif app.config.get("MAIL_TYPE") == "smtp": | |||
| case "smtp": | |||
| from libs.smtp import SMTPClient | |||
| if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): | |||
| if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT: | |||
| raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") | |||
| if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): | |||
| if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS: | |||
| raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") | |||
| self._client = SMTPClient( | |||
| server=app.config.get("SMTP_SERVER"), | |||
| port=app.config.get("SMTP_PORT"), | |||
| username=app.config.get("SMTP_USERNAME"), | |||
| password=app.config.get("SMTP_PASSWORD"), | |||
| _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), | |||
| use_tls=app.config.get("SMTP_USE_TLS"), | |||
| opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), | |||
| server=dify_config.SMTP_SERVER, | |||
| port=dify_config.SMTP_PORT, | |||
| username=dify_config.SMTP_USERNAME, | |||
| password=dify_config.SMTP_PASSWORD, | |||
| _from=dify_config.MAIL_DEFAULT_SEND_FROM, | |||
| use_tls=dify_config.SMTP_USE_TLS, | |||
| opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, | |||
| ) | |||
| else: | |||
| raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) | |||
| else: | |||
| logging.warning("MAIL_TYPE is not set") | |||
| case _: | |||
| raise ValueError("Unsupported mail type {}".format(mail_type)) | |||
| def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): | |||
| if not self._client: | |||
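The mail extension now resolves MAIL_TYPE once and dispatches with a match statement. A stripped-down sketch of that dispatch shape; the returned tuples are placeholders for the real resend and SMTP clients.

# Sketch of the match/case dispatch; client objects are placeholders, not real clients.
def build_mail_client(mail_type, resend_api_key=None):
    match mail_type:
        case None | "":
            return None  # mirrors the early "MAIL_TYPE is not set" warning-and-return
        case "resend":
            if not resend_api_key:
                raise ValueError("RESEND_API_KEY is not set")
            return ("resend", resend_api_key)
        case "smtp":
            return ("smtp", None)
        case _:
            raise ValueError(f"Unsupported mail type {mail_type}")


assert build_mail_client(None) is None
assert build_mail_client("resend", "re_example_key") == ("resend", "re_example_key")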
| @@ -2,6 +2,8 @@ import redis | |||
| from redis.connection import Connection, SSLConnection | |||
| from redis.sentinel import Sentinel | |||
| from configs import dify_config | |||
| class RedisClientWrapper(redis.Redis): | |||
| """ | |||
| @@ -43,37 +45,37 @@ redis_client = RedisClientWrapper() | |||
| def init_app(app): | |||
| global redis_client | |||
| connection_class = Connection | |||
| if app.config.get("REDIS_USE_SSL"): | |||
| if dify_config.REDIS_USE_SSL: | |||
| connection_class = SSLConnection | |||
| redis_params = { | |||
| "username": app.config.get("REDIS_USERNAME"), | |||
| "password": app.config.get("REDIS_PASSWORD"), | |||
| "db": app.config.get("REDIS_DB"), | |||
| "username": dify_config.REDIS_USERNAME, | |||
| "password": dify_config.REDIS_PASSWORD, | |||
| "db": dify_config.REDIS_DB, | |||
| "encoding": "utf-8", | |||
| "encoding_errors": "strict", | |||
| "decode_responses": False, | |||
| } | |||
| if app.config.get("REDIS_USE_SENTINEL"): | |||
| if dify_config.REDIS_USE_SENTINEL: | |||
| sentinel_hosts = [ | |||
| (node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",") | |||
| (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") | |||
| ] | |||
| sentinel = Sentinel( | |||
| sentinel_hosts, | |||
| sentinel_kwargs={ | |||
| "socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1), | |||
| "username": app.config.get("REDIS_SENTINEL_USERNAME"), | |||
| "password": app.config.get("REDIS_SENTINEL_PASSWORD"), | |||
| "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, | |||
| "username": dify_config.REDIS_SENTINEL_USERNAME, | |||
| "password": dify_config.REDIS_SENTINEL_PASSWORD, | |||
| }, | |||
| ) | |||
| master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params) | |||
| master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) | |||
| redis_client.initialize(master) | |||
| else: | |||
| redis_params.update( | |||
| { | |||
| "host": app.config.get("REDIS_HOST"), | |||
| "port": app.config.get("REDIS_PORT"), | |||
| "host": dify_config.REDIS_HOST, | |||
| "port": dify_config.REDIS_PORT, | |||
| "connection_class": connection_class, | |||
| } | |||
| ) | |||
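The Redis extension builds either a Sentinel-backed or a direct client from dify_config. A minimal sketch of the Sentinel path using redis-py; the host list and service name are placeholders.

# Sketch only: resolve the current master through Redis Sentinel with redis-py.
from redis.sentinel import Sentinel


def master_from_sentinels(sentinels_csv, service_name, sentinel_password=None):
    hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in sentinels_csv.split(",")]
    sentinel = Sentinel(hosts, sentinel_kwargs={"socket_timeout": 0.1, "password": sentinel_password})
    # master_for returns a redis.Redis client bound to whichever node Sentinel reports as master.
    return sentinel.master_for(service_name, db=0, decode_responses=False)


# Usage with placeholder values:
# client = master_from_sentinels("10.0.0.1:26379,10.0.0.2:26379", "mymaster")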
| @@ -5,6 +5,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration | |||
| from sentry_sdk.integrations.flask import FlaskIntegration | |||
| from werkzeug.exceptions import HTTPException | |||
| from configs import dify_config | |||
| from core.model_runtime.errors.invoke import InvokeRateLimitError | |||
| @@ -18,9 +19,9 @@ def before_send(event, hint): | |||
| def init_app(app): | |||
| if app.config.get("SENTRY_DSN"): | |||
| if dify_config.SENTRY_DSN: | |||
| sentry_sdk.init( | |||
| dsn=app.config.get("SENTRY_DSN"), | |||
| dsn=dify_config.SENTRY_DSN, | |||
| integrations=[FlaskIntegration(), CeleryIntegration()], | |||
| ignore_errors=[ | |||
| HTTPException, | |||
| @@ -29,9 +30,9 @@ def init_app(app): | |||
| InvokeRateLimitError, | |||
| parse_error.defaultErrorResponse, | |||
| ], | |||
| traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), | |||
| profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), | |||
| environment=app.config.get("DEPLOY_ENV"), | |||
| release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", | |||
| traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, | |||
| profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, | |||
| environment=dify_config.DEPLOY_ENV, | |||
| release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}", | |||
| before_send=before_send, | |||
| ) | |||
| @@ -15,7 +15,7 @@ class Storage: | |||
| def init_app(self, app: Flask): | |||
| storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) | |||
| self.storage_runner = storage_factory(app=app) | |||
| self.storage_runner = storage_factory() | |||
| @staticmethod | |||
| def get_storage_factory(storage_type: str) -> type[BaseStorage]: | |||
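Because every backend now reads dify_config in its own __init__, the factory can instantiate storage classes without passing the Flask app. A self-contained sketch of that no-argument factory shape; the backend class and type string are stand-ins.

# Sketch of a no-argument storage factory; LocalStorageSketch stands in for the real backends.
from abc import ABC, abstractmethod


class BaseStorageSketch(ABC):
    @abstractmethod
    def save(self, filename: str, data: bytes) -> None: ...


class LocalStorageSketch(BaseStorageSketch):
    def save(self, filename: str, data: bytes) -> None:
        with open(filename, "wb") as f:
            f.write(data)


def get_storage_factory(storage_type: str) -> type[BaseStorageSketch]:
    factories = {"local": LocalStorageSketch}
    if storage_type not in factories:
        raise ValueError(f"unsupported storage type {storage_type}")
    return factories[storage_type]


storage = get_storage_factory("local")()  # no Flask app argument needed any more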
| @@ -1,29 +1,27 @@ | |||
| from collections.abc import Generator | |||
| import oss2 as aliyun_s3 | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class AliyunOssStorage(BaseStorage): | |||
| """Implementation for Aliyun OSS storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") | |||
| self.folder = app.config.get("ALIYUN_OSS_PATH") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME | |||
| self.folder = dify_config.ALIYUN_OSS_PATH | |||
| oss_auth_method = aliyun_s3.Auth | |||
| region = None | |||
| if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": | |||
| if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4": | |||
| oss_auth_method = aliyun_s3.AuthV4 | |||
| region = app_config.get("ALIYUN_OSS_REGION") | |||
| oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) | |||
| region = dify_config.ALIYUN_OSS_REGION | |||
| oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY) | |||
| self.client = aliyun_s3.Bucket( | |||
| oss_auth, | |||
| app_config.get("ALIYUN_OSS_ENDPOINT"), | |||
| dify_config.ALIYUN_OSS_ENDPOINT, | |||
| self.bucket_name, | |||
| connect_timeout=30, | |||
| region=region, | |||
| @@ -2,8 +2,8 @@ from collections.abc import Generator | |||
| from datetime import datetime, timedelta, timezone | |||
| from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.storage.base_storage import BaseStorage | |||
| @@ -11,13 +11,12 @@ from extensions.storage.base_storage import BaseStorage | |||
| class AzureBlobStorage(BaseStorage): | |||
| """Implementation for Azure Blob storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") | |||
| self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") | |||
| self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") | |||
| self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME | |||
| self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL | |||
| self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME | |||
| self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY | |||
| def save(self, filename, data): | |||
| client = self._sync_client() | |||
| @@ -5,24 +5,23 @@ from collections.abc import Generator | |||
| from baidubce.auth.bce_credentials import BceCredentials | |||
| from baidubce.bce_client_configuration import BceClientConfiguration | |||
| from baidubce.services.bos.bos_client import BosClient | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class BaiduObsStorage(BaseStorage): | |||
| """Implementation for Baidu OBS storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("BAIDU_OBS_BUCKET_NAME") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME | |||
| client_config = BceClientConfiguration( | |||
| credentials=BceCredentials( | |||
| access_key_id=app_config.get("BAIDU_OBS_ACCESS_KEY"), | |||
| secret_access_key=app_config.get("BAIDU_OBS_SECRET_KEY"), | |||
| access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY, | |||
| secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY, | |||
| ), | |||
| endpoint=app_config.get("BAIDU_OBS_ENDPOINT"), | |||
| endpoint=dify_config.BAIDU_OBS_ENDPOINT, | |||
| ) | |||
| self.client = BosClient(config=client_config) | |||
| @@ -3,16 +3,12 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator | |||
| from flask import Flask | |||
| class BaseStorage(ABC): | |||
| """Interface for file storage.""" | |||
| app = None | |||
| def __init__(self, app: Flask): | |||
| self.app = app | |||
| def __init__(self): # noqa: B027 | |||
| pass | |||
| @abstractmethod | |||
| def save(self, filename, data): | |||
| @@ -3,20 +3,20 @@ import io | |||
| import json | |||
| from collections.abc import Generator | |||
| from flask import Flask | |||
| from google.cloud import storage as google_cloud_storage | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class GoogleCloudStorage(BaseStorage): | |||
| """Implementation for Google Cloud storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") | |||
| service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME | |||
| service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 | |||
| # if service_account_json_str is empty, use Application Default Credentials | |||
| if service_account_json_str: | |||
| service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") | |||
| @@ -1,22 +1,22 @@ | |||
| from collections.abc import Generator | |||
| from flask import Flask | |||
| from obs import ObsClient | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class HuaweiObsStorage(BaseStorage): | |||
| """Implementation for Huawei OBS storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME | |||
| self.client = ObsClient( | |||
| access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"), | |||
| secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"), | |||
| server=app_config.get("HUAWEI_OBS_SERVER"), | |||
| access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY, | |||
| secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY, | |||
| server=dify_config.HUAWEI_OBS_SERVER, | |||
| ) | |||
| def save(self, filename, data): | |||
| @@ -3,19 +3,20 @@ import shutil | |||
| from collections.abc import Generator | |||
| from pathlib import Path | |||
| from flask import Flask | |||
| from flask import current_app | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class LocalFsStorage(BaseStorage): | |||
| """Implementation for local filesystem storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| folder = self.app.config.get("STORAGE_LOCAL_PATH") | |||
| def __init__(self): | |||
| super().__init__() | |||
| folder = dify_config.STORAGE_LOCAL_PATH | |||
| if not os.path.isabs(folder): | |||
| folder = os.path.join(app.root_path, folder) | |||
| folder = os.path.join(current_app.root_path, folder) | |||
| self.folder = folder | |||
| def save(self, filename, data): | |||
| @@ -2,24 +2,24 @@ from collections.abc import Generator | |||
| import boto3 | |||
| from botocore.exceptions import ClientError | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class OracleOCIStorage(BaseStorage): | |||
| """Implementation for Oracle OCI storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("OCI_BUCKET_NAME") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.OCI_BUCKET_NAME | |||
| self.client = boto3.client( | |||
| "s3", | |||
| aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), | |||
| aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), | |||
| endpoint_url=app_config.get("OCI_ENDPOINT"), | |||
| region_name=app_config.get("OCI_REGION"), | |||
| aws_secret_access_key=dify_config.OCI_SECRET_KEY, | |||
| aws_access_key_id=dify_config.OCI_ACCESS_KEY, | |||
| endpoint_url=dify_config.OCI_ENDPOINT, | |||
| region_name=dify_config.OCI_REGION, | |||
| ) | |||
| def save(self, filename, data): | |||
| @@ -1,23 +1,23 @@ | |||
| from collections.abc import Generator | |||
| from flask import Flask | |||
| from qcloud_cos import CosConfig, CosS3Client | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class TencentCosStorage(BaseStorage): | |||
| """Implementation for Tencent Cloud COS storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME | |||
| config = CosConfig( | |||
| Region=app_config.get("TENCENT_COS_REGION"), | |||
| SecretId=app_config.get("TENCENT_COS_SECRET_ID"), | |||
| SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), | |||
| Scheme=app_config.get("TENCENT_COS_SCHEME"), | |||
| Region=dify_config.TENCENT_COS_REGION, | |||
| SecretId=dify_config.TENCENT_COS_SECRET_ID, | |||
| SecretKey=dify_config.TENCENT_COS_SECRET_KEY, | |||
| Scheme=dify_config.TENCENT_COS_SCHEME, | |||
| ) | |||
| self.client = CosS3Client(config) | |||
| @@ -1,23 +1,22 @@ | |||
| from collections.abc import Generator | |||
| import tos | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| from extensions.storage.base_storage import BaseStorage | |||
| class VolcengineTosStorage(BaseStorage): | |||
| """Implementation for Volcengine TOS storage.""" | |||
| def __init__(self, app: Flask): | |||
| super().__init__(app) | |||
| app_config = self.app.config | |||
| self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME") | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME | |||
| self.client = tos.TosClientV2( | |||
| ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"), | |||
| sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"), | |||
| endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"), | |||
| region=app_config.get("VOLCENGINE_TOS_REGION"), | |||
| ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, | |||
| sk=dify_config.VOLCENGINE_TOS_SECRET_KEY, | |||
| endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT, | |||
| region=dify_config.VOLCENGINE_TOS_REGION, | |||
| ) | |||
| def save(self, filename, data): | |||
| @@ -12,9 +12,10 @@ from hashlib import sha256 | |||
| from typing import Any, Optional, Union | |||
| from zoneinfo import available_timezones | |||
| from flask import Response, current_app, stream_with_context | |||
| from flask import Response, stream_with_context | |||
| from flask_restful import fields | |||
| from configs import dify_config | |||
| from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | |||
| from core.file import helpers as file_helpers | |||
| from extensions.ext_redis import redis_client | |||
| @@ -214,7 +215,7 @@ class TokenManager: | |||
| if additional_data: | |||
| token_data.update(additional_data) | |||
| expiry_minutes = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"] | |||
| expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") | |||
| token_key = cls._get_token_key(token, token_type) | |||
| expiry_time = int(expiry_minutes * 60) | |||
| redis_client.setex(token_key, expiry_time, json.dumps(token_data)) | |||
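The expiry lookup now derives the settings name from the token type, e.g. a reset_password token reads RESET_PASSWORD_TOKEN_EXPIRY_MINUTES. A small sketch of that dynamic lookup; the field name and default are assumptions for illustration.

# Sketch of the dynamic expiry lookup; the field and its default are assumed values.
from pydantic import Field
from pydantic_settings import BaseSettings


class TokenSketchConfig(BaseSettings):
    RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: float = Field(default=5)


config = TokenSketchConfig()
token_type = "reset_password"

expiry_minutes = config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
expiry_seconds = int(expiry_minutes * 60)
print(expiry_seconds)  # 300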
| @@ -1,4 +1,3 @@ | |||
| import os | |||
| from functools import wraps | |||
| from flask import current_app, g, has_request_context, request | |||
| @@ -7,6 +6,7 @@ from flask_login.config import EXEMPT_METHODS | |||
| from werkzeug.exceptions import Unauthorized | |||
| from werkzeug.local import LocalProxy | |||
| from configs import dify_config | |||
| from extensions.ext_database import db | |||
| from models.account import Account, Tenant, TenantAccountJoin | |||
| @@ -52,8 +52,7 @@ def login_required(func): | |||
| @wraps(func) | |||
| def decorated_view(*args, **kwargs): | |||
| auth_header = request.headers.get("Authorization") | |||
| admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") | |||
| if admin_api_key_enable.lower() == "true": | |||
| if dify_config.ADMIN_API_KEY_ENABLE: | |||
| if auth_header: | |||
| if " " not in auth_header: | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| @@ -61,10 +60,10 @@ def login_required(func): | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| admin_api_key = os.getenv("ADMIN_API_KEY") | |||
| admin_api_key = dify_config.ADMIN_API_KEY | |||
| if admin_api_key: | |||
| if os.getenv("ADMIN_API_KEY") == auth_token: | |||
| if admin_api_key == auth_token: | |||
| workspace_id = request.headers.get("X-WORKSPACE-ID") | |||
| if workspace_id: | |||
| tenant_account_join = ( | |||
| @@ -82,7 +81,7 @@ def login_required(func): | |||
| account.current_tenant = tenant | |||
| current_app.login_manager._update_request_context_with_user(account) | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) | |||
| if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): | |||
| if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | |||
| pass | |||
| elif not current_user.is_authenticated: | |||
| return current_app.login_manager.unauthorized() | |||
| @@ -1,6 +1,7 @@ | |||
| import pytest | |||
| from app_factory import create_app | |||
| from configs import dify_config | |||
| mock_user = type( | |||
| "MockUser", | |||
| @@ -20,5 +21,5 @@ mock_user = type( | |||
| @pytest.fixture | |||
| def app(): | |||
| app = create_app() | |||
| app.config["LOGIN_DISABLED"] = True | |||
| dify_config.LOGIN_DISABLED = True | |||
| return app | |||
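The fixture now flips LOGIN_DISABLED directly on the shared settings object, which works because the pydantic model allows attribute assignment. A hedged sketch of a fixture that also restores the previous value so other tests are unaffected; the fixture and class names are illustrative.

# Sketch of a restoring fixture; LoginSketchConfig stands in for the shared dify_config.
import pytest
from pydantic_settings import BaseSettings


class LoginSketchConfig(BaseSettings):
    LOGIN_DISABLED: bool = False


sketch_config = LoginSketchConfig()


@pytest.fixture
def login_disabled():
    previous = sketch_config.LOGIN_DISABLED
    sketch_config.LOGIN_DISABLED = True
    yield
    sketch_config.LOGIN_DISABLED = previous  # restore the default for subsequent tests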
| @@ -25,7 +25,7 @@ class VolcengineTosTest: | |||
| return cls._instance | |||
| def __init__(self): | |||
| self.storage = VolcengineTosStorage(app=Flask(__name__)) | |||
| self.storage = VolcengineTosStorage() | |||
| self.storage.bucket_name = get_example_bucket() | |||
| self.storage.client = TosClientV2( | |||
| ak="dify", | |||