| import os | import os | ||||
| from configs import dify_config | |||||
| if os.environ.get("DEBUG", "false").lower() != "true": | if os.environ.get("DEBUG", "false").lower() != "true": | ||||
| from gevent import monkey | from gevent import monkey | ||||
| time.tzset() | time.tzset() | ||||
| # ------------- | |||||
| # Configuration | |||||
| # ------------- | |||||
| config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first | |||||
| # create app | # create app | ||||
| app = create_app() | app = create_app() | ||||
| celery = app.extensions["celery"] | celery = app.extensions["celery"] | ||||
| if app.config.get("TESTING"): | |||||
| if dify_config.TESTING: | |||||
| print("App is running in TESTING mode") | print("App is running in TESTING mode") | ||||
| def after_request(response): | def after_request(response): | ||||
| """Add Version headers to the response.""" | """Add Version headers to the response.""" | ||||
| response.set_cookie("remember_token", "", expires=0) | response.set_cookie("remember_token", "", expires=0) | ||||
| response.headers.add("X-Version", app.config["CURRENT_VERSION"]) | |||||
| response.headers.add("X-Env", app.config["DEPLOY_ENV"]) | |||||
| response.headers.add("X-Version", dify_config.CURRENT_VERSION) | |||||
| response.headers.add("X-Env", dify_config.DEPLOY_ENV) | |||||
| return response | return response | ||||
| @app.route("/health") | @app.route("/health") | ||||
| def health(): | def health(): | ||||
| return Response( | return Response( | ||||
| json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), | |||||
| json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), | |||||
| status=200, | status=200, | ||||
| content_type="application/json", | content_type="application/json", | ||||
| ) | ) |
| def create_app() -> Flask: | def create_app() -> Flask: | ||||
| app = create_flask_app_with_configs() | app = create_flask_app_with_configs() | ||||
| app.secret_key = app.config["SECRET_KEY"] | |||||
| app.secret_key = dify_config.SECRET_KEY | |||||
| initialize_extensions(app) | initialize_extensions(app) | ||||
| register_blueprints(app) | register_blueprints(app) | ||||
| register_commands(app) | register_commands(app) | ||||
| CORS( | CORS( | ||||
| web_bp, | web_bp, | ||||
| resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, | |||||
| resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, | |||||
| supports_credentials=True, | supports_credentials=True, | ||||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | allow_headers=["Content-Type", "Authorization", "X-App-Code"], | ||||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | ||||
| CORS( | CORS( | ||||
| console_app_bp, | console_app_bp, | ||||
| resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, | |||||
| resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, | |||||
| supports_credentials=True, | supports_credentials=True, | ||||
| allow_headers=["Content-Type", "Authorization"], | allow_headers=["Content-Type", "Authorization"], | ||||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], |
| default=5, | default=5, | ||||
| ) | ) | ||||
| LOGIN_DISABLED: bool = Field( | |||||
| description="Whether to disable login checks", | |||||
| default=False, | |||||
| ) | |||||
| ADMIN_API_KEY_ENABLE: bool = Field( | |||||
| description="Whether to enable admin api key for authentication", | |||||
| default=False, | |||||
| ) | |||||
| ADMIN_API_KEY: Optional[str] = Field( | |||||
| description="admin api key for authentication", | |||||
| default=None, | |||||
| ) | |||||
| class AppExecutionConfig(BaseSettings): | class AppExecutionConfig(BaseSettings): | ||||
| """ | """ |
| import os | |||||
| from functools import wraps | from functools import wraps | ||||
| from flask import request | from flask import request | ||||
| from flask_restful import Resource, reqparse | from flask_restful import Resource, reqparse | ||||
| from werkzeug.exceptions import NotFound, Unauthorized | from werkzeug.exceptions import NotFound, Unauthorized | ||||
| from configs import dify_config | |||||
| from constants.languages import supported_language | from constants.languages import supported_language | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.wraps import only_edition_cloud | from controllers.console.wraps import only_edition_cloud | ||||
| def admin_required(view): | def admin_required(view): | ||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | def decorated(*args, **kwargs): | ||||
| if not os.getenv("ADMIN_API_KEY"): | |||||
| if not dify_config.ADMIN_API_KEY: | |||||
| raise Unauthorized("API key is invalid.") | raise Unauthorized("API key is invalid.") | ||||
| auth_header = request.headers.get("Authorization") | auth_header = request.headers.get("Authorization") | ||||
| if auth_scheme != "bearer": | if auth_scheme != "bearer": | ||||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
| if os.getenv("ADMIN_API_KEY") != auth_token: | |||||
| if dify_config.ADMIN_API_KEY != auth_token: | |||||
| raise Unauthorized("API key is invalid.") | raise Unauthorized("API key is invalid.") | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) |
| from typing import Optional | from typing import Optional | ||||
| from flask import Config, Flask | |||||
| from flask import Flask | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from configs import dify_config | |||||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | from core.entities.provider_entities import QuotaUnit, RestrictModel | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from models.provider import ProviderQuotaType | from models.provider import ProviderQuotaType | ||||
| moderation_config: HostedModerationConfig = None | moderation_config: HostedModerationConfig = None | ||||
| def init_app(self, app: Flask) -> None: | def init_app(self, app: Flask) -> None: | ||||
| config = app.config | |||||
| if config.get("EDITION") != "CLOUD": | |||||
| if dify_config.EDITION != "CLOUD": | |||||
| return | return | ||||
| self.provider_map["azure_openai"] = self.init_azure_openai(config) | |||||
| self.provider_map["openai"] = self.init_openai(config) | |||||
| self.provider_map["anthropic"] = self.init_anthropic(config) | |||||
| self.provider_map["minimax"] = self.init_minimax(config) | |||||
| self.provider_map["spark"] = self.init_spark(config) | |||||
| self.provider_map["zhipuai"] = self.init_zhipuai(config) | |||||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||||
| self.provider_map["openai"] = self.init_openai() | |||||
| self.provider_map["anthropic"] = self.init_anthropic() | |||||
| self.provider_map["minimax"] = self.init_minimax() | |||||
| self.provider_map["spark"] = self.init_spark() | |||||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||||
| self.moderation_config = self.init_moderation_config(config) | |||||
| self.moderation_config = self.init_moderation_config() | |||||
| @staticmethod | @staticmethod | ||||
| def init_azure_openai(app_config: Config) -> HostingProvider: | |||||
| def init_azure_openai() -> HostingProvider: | |||||
| quota_unit = QuotaUnit.TIMES | quota_unit = QuotaUnit.TIMES | ||||
| if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): | |||||
| if dify_config.HOSTED_AZURE_OPENAI_ENABLED: | |||||
| credentials = { | credentials = { | ||||
| "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), | |||||
| "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), | |||||
| "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY, | |||||
| "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE, | |||||
| "base_model_name": "gpt-35-turbo", | "base_model_name": "gpt-35-turbo", | ||||
| } | } | ||||
| quotas = [] | quotas = [] | ||||
| hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) | |||||
| hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT | |||||
| trial_quota = TrialHostingQuota( | trial_quota = TrialHostingQuota( | ||||
| quota_limit=hosted_quota_limit, | quota_limit=hosted_quota_limit, | ||||
| restrict_models=[ | restrict_models=[ | ||||
| quota_unit=quota_unit, | quota_unit=quota_unit, | ||||
| ) | ) | ||||
| def init_openai(self, app_config: Config) -> HostingProvider: | |||||
| def init_openai(self) -> HostingProvider: | |||||
| quota_unit = QuotaUnit.CREDITS | quota_unit = QuotaUnit.CREDITS | ||||
| quotas = [] | quotas = [] | ||||
| if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): | |||||
| hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) | |||||
| trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") | |||||
| if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: | |||||
| hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT | |||||
| trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") | |||||
| trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) | trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) | ||||
| quotas.append(trial_quota) | quotas.append(trial_quota) | ||||
| if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): | |||||
| paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") | |||||
| if dify_config.HOSTED_OPENAI_PAID_ENABLED: | |||||
| paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS") | |||||
| paid_quota = PaidHostingQuota(restrict_models=paid_models) | paid_quota = PaidHostingQuota(restrict_models=paid_models) | ||||
| quotas.append(paid_quota) | quotas.append(paid_quota) | ||||
| if len(quotas) > 0: | if len(quotas) > 0: | ||||
| credentials = { | credentials = { | ||||
| "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"), | |||||
| "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY, | |||||
| } | } | ||||
| if app_config.get("HOSTED_OPENAI_API_BASE"): | |||||
| credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE") | |||||
| if dify_config.HOSTED_OPENAI_API_BASE: | |||||
| credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE | |||||
| if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): | |||||
| credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") | |||||
| if dify_config.HOSTED_OPENAI_API_ORGANIZATION: | |||||
| credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION | |||||
| return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def init_anthropic(app_config: Config) -> HostingProvider: | |||||
| def init_anthropic() -> HostingProvider: | |||||
| quota_unit = QuotaUnit.TOKENS | quota_unit = QuotaUnit.TOKENS | ||||
| quotas = [] | quotas = [] | ||||
| if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): | |||||
| hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) | |||||
| if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: | |||||
| hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT | |||||
| trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) | trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) | ||||
| quotas.append(trial_quota) | quotas.append(trial_quota) | ||||
| if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): | |||||
| if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: | |||||
| paid_quota = PaidHostingQuota() | paid_quota = PaidHostingQuota() | ||||
| quotas.append(paid_quota) | quotas.append(paid_quota) | ||||
| if len(quotas) > 0: | if len(quotas) > 0: | ||||
| credentials = { | credentials = { | ||||
| "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"), | |||||
| "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY, | |||||
| } | } | ||||
| if app_config.get("HOSTED_ANTHROPIC_API_BASE"): | |||||
| credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") | |||||
| if dify_config.HOSTED_ANTHROPIC_API_BASE: | |||||
| credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE | |||||
| return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def init_minimax(app_config: Config) -> HostingProvider: | |||||
| def init_minimax() -> HostingProvider: | |||||
| quota_unit = QuotaUnit.TOKENS | quota_unit = QuotaUnit.TOKENS | ||||
| if app_config.get("HOSTED_MINIMAX_ENABLED"): | |||||
| if dify_config.HOSTED_MINIMAX_ENABLED: | |||||
| quotas = [FreeHostingQuota()] | quotas = [FreeHostingQuota()] | ||||
| return HostingProvider( | return HostingProvider( | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def init_spark(app_config: Config) -> HostingProvider: | |||||
| def init_spark() -> HostingProvider: | |||||
| quota_unit = QuotaUnit.TOKENS | quota_unit = QuotaUnit.TOKENS | ||||
| if app_config.get("HOSTED_SPARK_ENABLED"): | |||||
| if dify_config.HOSTED_SPARK_ENABLED: | |||||
| quotas = [FreeHostingQuota()] | quotas = [FreeHostingQuota()] | ||||
| return HostingProvider( | return HostingProvider( | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def init_zhipuai(app_config: Config) -> HostingProvider: | |||||
| def init_zhipuai() -> HostingProvider: | |||||
| quota_unit = QuotaUnit.TOKENS | quota_unit = QuotaUnit.TOKENS | ||||
| if app_config.get("HOSTED_ZHIPUAI_ENABLED"): | |||||
| if dify_config.HOSTED_ZHIPUAI_ENABLED: | |||||
| quotas = [FreeHostingQuota()] | quotas = [FreeHostingQuota()] | ||||
| return HostingProvider( | return HostingProvider( | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def init_moderation_config(app_config: Config) -> HostedModerationConfig: | |||||
| if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): | |||||
| return HostedModerationConfig( | |||||
| enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") | |||||
| ) | |||||
| def init_moderation_config() -> HostedModerationConfig: | |||||
| if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: | |||||
| return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) | |||||
| return HostedModerationConfig(enabled=False) | return HostedModerationConfig(enabled=False) | ||||
| @staticmethod | @staticmethod | ||||
| def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: | |||||
| models_str = app_config.get(env_var) | |||||
| def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]: | |||||
| models_str = dify_config.model_dump().get(env_var) | |||||
| models_list = models_str.split(",") if models_str else [] | models_list = models_str.split(",") if models_str else [] | ||||
| return [ | return [ | ||||
| RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) | RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) |
| if not dataset.index_struct_dict: | if not dataset.index_struct_dict: | ||||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) | dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) | ||||
| config = current_app.config | |||||
| return QdrantVector( | return QdrantVector( | ||||
| collection_name=collection_name, | collection_name=collection_name, | ||||
| group_id=dataset.id, | group_id=dataset.id, | ||||
| config=QdrantConfig( | config=QdrantConfig( | ||||
| endpoint=dify_config.QDRANT_URL, | endpoint=dify_config.QDRANT_URL, | ||||
| api_key=dify_config.QDRANT_API_KEY, | api_key=dify_config.QDRANT_API_KEY, | ||||
| root_path=config.root_path, | |||||
| root_path=current_app.config.root_path, | |||||
| timeout=dify_config.QDRANT_CLIENT_TIMEOUT, | timeout=dify_config.QDRANT_CLIENT_TIMEOUT, | ||||
| grpc_port=dify_config.QDRANT_GRPC_PORT, | grpc_port=dify_config.QDRANT_GRPC_PORT, | ||||
| prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, | prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, |
| from celery import Celery, Task | from celery import Celery, Task | ||||
| from flask import Flask | from flask import Flask | ||||
| from configs import dify_config | |||||
| def init_app(app: Flask) -> Celery: | def init_app(app: Flask) -> Celery: | ||||
| class FlaskTask(Task): | class FlaskTask(Task): | ||||
| broker_transport_options = {} | broker_transport_options = {} | ||||
| if app.config.get("CELERY_USE_SENTINEL"): | |||||
| if dify_config.CELERY_USE_SENTINEL: | |||||
| broker_transport_options = { | broker_transport_options = { | ||||
| "master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"), | |||||
| "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, | |||||
| "sentinel_kwargs": { | "sentinel_kwargs": { | ||||
| "socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1), | |||||
| "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, | |||||
| }, | }, | ||||
| } | } | ||||
| celery_app = Celery( | celery_app = Celery( | ||||
| app.name, | app.name, | ||||
| task_cls=FlaskTask, | task_cls=FlaskTask, | ||||
| broker=app.config.get("CELERY_BROKER_URL"), | |||||
| backend=app.config.get("CELERY_BACKEND"), | |||||
| broker=dify_config.CELERY_BROKER_URL, | |||||
| backend=dify_config.CELERY_BACKEND, | |||||
| task_ignore_result=True, | task_ignore_result=True, | ||||
| ) | ) | ||||
| } | } | ||||
| celery_app.conf.update( | celery_app.conf.update( | ||||
| result_backend=app.config.get("CELERY_RESULT_BACKEND"), | |||||
| result_backend=dify_config.CELERY_RESULT_BACKEND, | |||||
| broker_transport_options=broker_transport_options, | broker_transport_options=broker_transport_options, | ||||
| broker_connection_retry_on_startup=True, | broker_connection_retry_on_startup=True, | ||||
| ) | ) | ||||
| if app.config.get("BROKER_USE_SSL"): | |||||
| if dify_config.BROKER_USE_SSL: | |||||
| celery_app.conf.update( | celery_app.conf.update( | ||||
| broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration | broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration | ||||
| ) | ) | ||||
| "schedule.clean_embedding_cache_task", | "schedule.clean_embedding_cache_task", | ||||
| "schedule.clean_unused_datasets_task", | "schedule.clean_unused_datasets_task", | ||||
| ] | ] | ||||
| day = app.config.get("CELERY_BEAT_SCHEDULER_TIME") | |||||
| day = dify_config.CELERY_BEAT_SCHEDULER_TIME | |||||
| beat_schedule = { | beat_schedule = { | ||||
| "clean_embedding_cache_task": { | "clean_embedding_cache_task": { | ||||
| "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", | "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", |
| from flask import Flask | from flask import Flask | ||||
| from configs import dify_config | |||||
| def init_app(app: Flask): | def init_app(app: Flask): | ||||
| if app.config.get("API_COMPRESSION_ENABLED"): | |||||
| if dify_config.API_COMPRESSION_ENABLED: | |||||
| from flask_compress import Compress | from flask_compress import Compress | ||||
| app.config["COMPRESS_MIMETYPES"] = [ | app.config["COMPRESS_MIMETYPES"] = [ |
| from flask import Flask | from flask import Flask | ||||
| from configs import dify_config | |||||
| def init_app(app: Flask): | def init_app(app: Flask): | ||||
| log_handlers = None | log_handlers = None | ||||
| log_file = app.config.get("LOG_FILE") | |||||
| log_file = dify_config.LOG_FILE | |||||
| if log_file: | if log_file: | ||||
| log_dir = os.path.dirname(log_file) | log_dir = os.path.dirname(log_file) | ||||
| os.makedirs(log_dir, exist_ok=True) | os.makedirs(log_dir, exist_ok=True) | ||||
| ] | ] | ||||
| logging.basicConfig( | logging.basicConfig( | ||||
| level=app.config.get("LOG_LEVEL"), | |||||
| format=app.config.get("LOG_FORMAT"), | |||||
| datefmt=app.config.get("LOG_DATEFORMAT"), | |||||
| level=dify_config.LOG_LEVEL, | |||||
| format=dify_config.LOG_FORMAT, | |||||
| datefmt=dify_config.LOG_DATEFORMAT, | |||||
| handlers=log_handlers, | handlers=log_handlers, | ||||
| force=True, | force=True, | ||||
| ) | ) | ||||
| log_tz = app.config.get("LOG_TZ") | |||||
| log_tz = dify_config.LOG_TZ | |||||
| if log_tz: | if log_tz: | ||||
| from datetime import datetime | from datetime import datetime | ||||
| import resend | import resend | ||||
| from flask import Flask | from flask import Flask | ||||
| from configs import dify_config | |||||
| class Mail: | class Mail: | ||||
| def __init__(self): | def __init__(self): | ||||
| return self._client is not None | return self._client is not None | ||||
| def init_app(self, app: Flask): | def init_app(self, app: Flask): | ||||
| if app.config.get("MAIL_TYPE"): | |||||
| if app.config.get("MAIL_DEFAULT_SEND_FROM"): | |||||
| self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") | |||||
| mail_type = dify_config.MAIL_TYPE | |||||
| if not mail_type: | |||||
| logging.warning("MAIL_TYPE is not set") | |||||
| return | |||||
| if app.config.get("MAIL_TYPE") == "resend": | |||||
| api_key = app.config.get("RESEND_API_KEY") | |||||
| if dify_config.MAIL_DEFAULT_SEND_FROM: | |||||
| self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM | |||||
| match mail_type: | |||||
| case "resend": | |||||
| api_key = dify_config.RESEND_API_KEY | |||||
| if not api_key: | if not api_key: | ||||
| raise ValueError("RESEND_API_KEY is not set") | raise ValueError("RESEND_API_KEY is not set") | ||||
| api_url = app.config.get("RESEND_API_URL") | |||||
| api_url = dify_config.RESEND_API_URL | |||||
| if api_url: | if api_url: | ||||
| resend.api_url = api_url | resend.api_url = api_url | ||||
| resend.api_key = api_key | resend.api_key = api_key | ||||
| self._client = resend.Emails | self._client = resend.Emails | ||||
| elif app.config.get("MAIL_TYPE") == "smtp": | |||||
| case "smtp": | |||||
| from libs.smtp import SMTPClient | from libs.smtp import SMTPClient | ||||
| if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): | |||||
| if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT: | |||||
| raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") | raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") | ||||
| if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): | |||||
| if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS: | |||||
| raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") | raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") | ||||
| self._client = SMTPClient( | self._client = SMTPClient( | ||||
| server=app.config.get("SMTP_SERVER"), | |||||
| port=app.config.get("SMTP_PORT"), | |||||
| username=app.config.get("SMTP_USERNAME"), | |||||
| password=app.config.get("SMTP_PASSWORD"), | |||||
| _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), | |||||
| use_tls=app.config.get("SMTP_USE_TLS"), | |||||
| opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), | |||||
| server=dify_config.SMTP_SERVER, | |||||
| port=dify_config.SMTP_PORT, | |||||
| username=dify_config.SMTP_USERNAME, | |||||
| password=dify_config.SMTP_PASSWORD, | |||||
| _from=dify_config.MAIL_DEFAULT_SEND_FROM, | |||||
| use_tls=dify_config.SMTP_USE_TLS, | |||||
| opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, | |||||
| ) | ) | ||||
| else: | |||||
| raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) | |||||
| else: | |||||
| logging.warning("MAIL_TYPE is not set") | |||||
| case _: | |||||
| raise ValueError("Unsupported mail type {}".format(mail_type)) | |||||
| def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): | def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): | ||||
| if not self._client: | if not self._client: |
| from redis.connection import Connection, SSLConnection | from redis.connection import Connection, SSLConnection | ||||
| from redis.sentinel import Sentinel | from redis.sentinel import Sentinel | ||||
| from configs import dify_config | |||||
| class RedisClientWrapper(redis.Redis): | class RedisClientWrapper(redis.Redis): | ||||
| """ | """ | ||||
| def init_app(app): | def init_app(app): | ||||
| global redis_client | global redis_client | ||||
| connection_class = Connection | connection_class = Connection | ||||
| if app.config.get("REDIS_USE_SSL"): | |||||
| if dify_config.REDIS_USE_SSL: | |||||
| connection_class = SSLConnection | connection_class = SSLConnection | ||||
| redis_params = { | redis_params = { | ||||
| "username": app.config.get("REDIS_USERNAME"), | |||||
| "password": app.config.get("REDIS_PASSWORD"), | |||||
| "db": app.config.get("REDIS_DB"), | |||||
| "username": dify_config.REDIS_USERNAME, | |||||
| "password": dify_config.REDIS_PASSWORD, | |||||
| "db": dify_config.REDIS_DB, | |||||
| "encoding": "utf-8", | "encoding": "utf-8", | ||||
| "encoding_errors": "strict", | "encoding_errors": "strict", | ||||
| "decode_responses": False, | "decode_responses": False, | ||||
| } | } | ||||
| if app.config.get("REDIS_USE_SENTINEL"): | |||||
| if dify_config.REDIS_USE_SENTINEL: | |||||
| sentinel_hosts = [ | sentinel_hosts = [ | ||||
| (node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",") | |||||
| (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") | |||||
| ] | ] | ||||
| sentinel = Sentinel( | sentinel = Sentinel( | ||||
| sentinel_hosts, | sentinel_hosts, | ||||
| sentinel_kwargs={ | sentinel_kwargs={ | ||||
| "socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1), | |||||
| "username": app.config.get("REDIS_SENTINEL_USERNAME"), | |||||
| "password": app.config.get("REDIS_SENTINEL_PASSWORD"), | |||||
| "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, | |||||
| "username": dify_config.REDIS_SENTINEL_USERNAME, | |||||
| "password": dify_config.REDIS_SENTINEL_PASSWORD, | |||||
| }, | }, | ||||
| ) | ) | ||||
| master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params) | |||||
| master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) | |||||
| redis_client.initialize(master) | redis_client.initialize(master) | ||||
| else: | else: | ||||
| redis_params.update( | redis_params.update( | ||||
| { | { | ||||
| "host": app.config.get("REDIS_HOST"), | |||||
| "port": app.config.get("REDIS_PORT"), | |||||
| "host": dify_config.REDIS_HOST, | |||||
| "port": dify_config.REDIS_PORT, | |||||
| "connection_class": connection_class, | "connection_class": connection_class, | ||||
| } | } | ||||
| ) | ) |
| from sentry_sdk.integrations.flask import FlaskIntegration | from sentry_sdk.integrations.flask import FlaskIntegration | ||||
| from werkzeug.exceptions import HTTPException | from werkzeug.exceptions import HTTPException | ||||
| from configs import dify_config | |||||
| from core.model_runtime.errors.invoke import InvokeRateLimitError | from core.model_runtime.errors.invoke import InvokeRateLimitError | ||||
| def init_app(app): | def init_app(app): | ||||
| if app.config.get("SENTRY_DSN"): | |||||
| if dify_config.SENTRY_DSN: | |||||
| sentry_sdk.init( | sentry_sdk.init( | ||||
| dsn=app.config.get("SENTRY_DSN"), | |||||
| dsn=dify_config.SENTRY_DSN, | |||||
| integrations=[FlaskIntegration(), CeleryIntegration()], | integrations=[FlaskIntegration(), CeleryIntegration()], | ||||
| ignore_errors=[ | ignore_errors=[ | ||||
| HTTPException, | HTTPException, | ||||
| InvokeRateLimitError, | InvokeRateLimitError, | ||||
| parse_error.defaultErrorResponse, | parse_error.defaultErrorResponse, | ||||
| ], | ], | ||||
| traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), | |||||
| profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), | |||||
| environment=app.config.get("DEPLOY_ENV"), | |||||
| release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", | |||||
| traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, | |||||
| profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, | |||||
| environment=dify_config.DEPLOY_ENV, | |||||
| release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}", | |||||
| before_send=before_send, | before_send=before_send, | ||||
| ) | ) |
| def init_app(self, app: Flask): | def init_app(self, app: Flask): | ||||
| storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) | storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) | ||||
| self.storage_runner = storage_factory(app=app) | |||||
| self.storage_runner = storage_factory() | |||||
| @staticmethod | @staticmethod | ||||
| def get_storage_factory(storage_type: str) -> type[BaseStorage]: | def get_storage_factory(storage_type: str) -> type[BaseStorage]: |
| from collections.abc import Generator | from collections.abc import Generator | ||||
| import oss2 as aliyun_s3 | import oss2 as aliyun_s3 | ||||
| from flask import Flask | |||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class AliyunOssStorage(BaseStorage): | class AliyunOssStorage(BaseStorage): | ||||
| """Implementation for Aliyun OSS storage.""" | """Implementation for Aliyun OSS storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") | |||||
| self.folder = app.config.get("ALIYUN_OSS_PATH") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME | |||||
| self.folder = dify_config.ALIYUN_OSS_PATH | |||||
| oss_auth_method = aliyun_s3.Auth | oss_auth_method = aliyun_s3.Auth | ||||
| region = None | region = None | ||||
| if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": | |||||
| if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4": | |||||
| oss_auth_method = aliyun_s3.AuthV4 | oss_auth_method = aliyun_s3.AuthV4 | ||||
| region = app_config.get("ALIYUN_OSS_REGION") | |||||
| oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) | |||||
| region = dify_config.ALIYUN_OSS_REGION | |||||
| oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY) | |||||
| self.client = aliyun_s3.Bucket( | self.client = aliyun_s3.Bucket( | ||||
| oss_auth, | oss_auth, | ||||
| app_config.get("ALIYUN_OSS_ENDPOINT"), | |||||
| dify_config.ALIYUN_OSS_ENDPOINT, | |||||
| self.bucket_name, | self.bucket_name, | ||||
| connect_timeout=30, | connect_timeout=30, | ||||
| region=region, | region=region, |
| from datetime import datetime, timedelta, timezone | from datetime import datetime, timedelta, timezone | ||||
| from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas | from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas | ||||
| from flask import Flask | |||||
| from configs import dify_config | |||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class AzureBlobStorage(BaseStorage): | class AzureBlobStorage(BaseStorage): | ||||
| """Implementation for Azure Blob storage.""" | """Implementation for Azure Blob storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") | |||||
| self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") | |||||
| self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") | |||||
| self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME | |||||
| self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL | |||||
| self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME | |||||
| self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY | |||||
| def save(self, filename, data): | def save(self, filename, data): | ||||
| client = self._sync_client() | client = self._sync_client() |
| from baidubce.auth.bce_credentials import BceCredentials | from baidubce.auth.bce_credentials import BceCredentials | ||||
| from baidubce.bce_client_configuration import BceClientConfiguration | from baidubce.bce_client_configuration import BceClientConfiguration | ||||
| from baidubce.services.bos.bos_client import BosClient | from baidubce.services.bos.bos_client import BosClient | ||||
| from flask import Flask | |||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class BaiduObsStorage(BaseStorage): | class BaiduObsStorage(BaseStorage): | ||||
| """Implementation for Baidu OBS storage.""" | """Implementation for Baidu OBS storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("BAIDU_OBS_BUCKET_NAME") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME | |||||
| client_config = BceClientConfiguration( | client_config = BceClientConfiguration( | ||||
| credentials=BceCredentials( | credentials=BceCredentials( | ||||
| access_key_id=app_config.get("BAIDU_OBS_ACCESS_KEY"), | |||||
| secret_access_key=app_config.get("BAIDU_OBS_SECRET_KEY"), | |||||
| access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY, | |||||
| secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY, | |||||
| ), | ), | ||||
| endpoint=app_config.get("BAIDU_OBS_ENDPOINT"), | |||||
| endpoint=dify_config.BAIDU_OBS_ENDPOINT, | |||||
| ) | ) | ||||
| self.client = BosClient(config=client_config) | self.client = BosClient(config=client_config) |
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from flask import Flask | |||||
| class BaseStorage(ABC): | class BaseStorage(ABC): | ||||
| """Interface for file storage.""" | """Interface for file storage.""" | ||||
| app = None | |||||
| def __init__(self, app: Flask): | |||||
| self.app = app | |||||
| def __init__(self): # noqa: B027 | |||||
| pass | |||||
| @abstractmethod | @abstractmethod | ||||
| def save(self, filename, data): | def save(self, filename, data): |
| import json | import json | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from flask import Flask | |||||
| from google.cloud import storage as google_cloud_storage | from google.cloud import storage as google_cloud_storage | ||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class GoogleCloudStorage(BaseStorage): | class GoogleCloudStorage(BaseStorage): | ||||
| """Implementation for Google Cloud storage.""" | """Implementation for Google Cloud storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") | |||||
| service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME | |||||
| service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 | |||||
| # if service_account_json_str is empty, use Application Default Credentials | # if service_account_json_str is empty, use Application Default Credentials | ||||
| if service_account_json_str: | if service_account_json_str: | ||||
| service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") | service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") |
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from flask import Flask | |||||
| from obs import ObsClient | from obs import ObsClient | ||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class HuaweiObsStorage(BaseStorage): | class HuaweiObsStorage(BaseStorage): | ||||
| """Implementation for Huawei OBS storage.""" | """Implementation for Huawei OBS storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME | |||||
| self.client = ObsClient( | self.client = ObsClient( | ||||
| access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"), | |||||
| secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"), | |||||
| server=app_config.get("HUAWEI_OBS_SERVER"), | |||||
| access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY, | |||||
| secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY, | |||||
| server=dify_config.HUAWEI_OBS_SERVER, | |||||
| ) | ) | ||||
| def save(self, filename, data): | def save(self, filename, data): |
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from pathlib import Path | from pathlib import Path | ||||
| from flask import Flask | |||||
| from flask import current_app | |||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class LocalFsStorage(BaseStorage): | class LocalFsStorage(BaseStorage): | ||||
| """Implementation for local filesystem storage.""" | """Implementation for local filesystem storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| folder = self.app.config.get("STORAGE_LOCAL_PATH") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| folder = dify_config.STORAGE_LOCAL_PATH | |||||
| if not os.path.isabs(folder): | if not os.path.isabs(folder): | ||||
| folder = os.path.join(app.root_path, folder) | |||||
| folder = os.path.join(current_app.root_path, folder) | |||||
| self.folder = folder | self.folder = folder | ||||
| def save(self, filename, data): | def save(self, filename, data): |
| import boto3 | import boto3 | ||||
| from botocore.exceptions import ClientError | from botocore.exceptions import ClientError | ||||
| from flask import Flask | |||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class OracleOCIStorage(BaseStorage): | class OracleOCIStorage(BaseStorage): | ||||
| """Implementation for Oracle OCI storage.""" | """Implementation for Oracle OCI storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("OCI_BUCKET_NAME") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.OCI_BUCKET_NAME | |||||
| self.client = boto3.client( | self.client = boto3.client( | ||||
| "s3", | "s3", | ||||
| aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), | |||||
| aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), | |||||
| endpoint_url=app_config.get("OCI_ENDPOINT"), | |||||
| region_name=app_config.get("OCI_REGION"), | |||||
| aws_secret_access_key=dify_config.OCI_SECRET_KEY, | |||||
| aws_access_key_id=dify_config.OCI_ACCESS_KEY, | |||||
| endpoint_url=dify_config.OCI_ENDPOINT, | |||||
| region_name=dify_config.OCI_REGION, | |||||
| ) | ) | ||||
| def save(self, filename, data): | def save(self, filename, data): |
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from flask import Flask | |||||
| from qcloud_cos import CosConfig, CosS3Client | from qcloud_cos import CosConfig, CosS3Client | ||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class TencentCosStorage(BaseStorage): | class TencentCosStorage(BaseStorage): | ||||
| """Implementation for Tencent Cloud COS storage.""" | """Implementation for Tencent Cloud COS storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME | |||||
| config = CosConfig( | config = CosConfig( | ||||
| Region=app_config.get("TENCENT_COS_REGION"), | |||||
| SecretId=app_config.get("TENCENT_COS_SECRET_ID"), | |||||
| SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), | |||||
| Scheme=app_config.get("TENCENT_COS_SCHEME"), | |||||
| Region=dify_config.TENCENT_COS_REGION, | |||||
| SecretId=dify_config.TENCENT_COS_SECRET_ID, | |||||
| SecretKey=dify_config.TENCENT_COS_SECRET_KEY, | |||||
| Scheme=dify_config.TENCENT_COS_SCHEME, | |||||
| ) | ) | ||||
| self.client = CosS3Client(config) | self.client = CosS3Client(config) | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| import tos | import tos | ||||
| from flask import Flask | |||||
| from configs import dify_config | |||||
| from extensions.storage.base_storage import BaseStorage | from extensions.storage.base_storage import BaseStorage | ||||
| class VolcengineTosStorage(BaseStorage): | class VolcengineTosStorage(BaseStorage): | ||||
| """Implementation for Volcengine TOS storage.""" | """Implementation for Volcengine TOS storage.""" | ||||
| def __init__(self, app: Flask): | |||||
| super().__init__(app) | |||||
| app_config = self.app.config | |||||
| self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME") | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME | |||||
| self.client = tos.TosClientV2( | self.client = tos.TosClientV2( | ||||
| ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"), | |||||
| sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"), | |||||
| endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"), | |||||
| region=app_config.get("VOLCENGINE_TOS_REGION"), | |||||
| ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, | |||||
| sk=dify_config.VOLCENGINE_TOS_SECRET_KEY, | |||||
| endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT, | |||||
| region=dify_config.VOLCENGINE_TOS_REGION, | |||||
| ) | ) | ||||
| def save(self, filename, data): | def save(self, filename, data): |
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| from zoneinfo import available_timezones | from zoneinfo import available_timezones | ||||
| from flask import Response, current_app, stream_with_context | |||||
| from flask import Response, stream_with_context | |||||
| from flask_restful import fields | from flask_restful import fields | ||||
| from configs import dify_config | |||||
| from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | ||||
| from core.file import helpers as file_helpers | from core.file import helpers as file_helpers | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| if additional_data: | if additional_data: | ||||
| token_data.update(additional_data) | token_data.update(additional_data) | ||||
| expiry_minutes = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"] | |||||
| expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") | |||||
| token_key = cls._get_token_key(token, token_type) | token_key = cls._get_token_key(token, token_type) | ||||
| expiry_time = int(expiry_minutes * 60) | expiry_time = int(expiry_minutes * 60) | ||||
| redis_client.setex(token_key, expiry_time, json.dumps(token_data)) | redis_client.setex(token_key, expiry_time, json.dumps(token_data)) |
| import os | |||||
| from functools import wraps | from functools import wraps | ||||
| from flask import current_app, g, has_request_context, request | from flask import current_app, g, has_request_context, request | ||||
| from werkzeug.exceptions import Unauthorized | from werkzeug.exceptions import Unauthorized | ||||
| from werkzeug.local import LocalProxy | from werkzeug.local import LocalProxy | ||||
| from configs import dify_config | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account, Tenant, TenantAccountJoin | from models.account import Account, Tenant, TenantAccountJoin | ||||
| @wraps(func) | @wraps(func) | ||||
| def decorated_view(*args, **kwargs): | def decorated_view(*args, **kwargs): | ||||
| auth_header = request.headers.get("Authorization") | auth_header = request.headers.get("Authorization") | ||||
| admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") | |||||
| if admin_api_key_enable.lower() == "true": | |||||
| if dify_config.ADMIN_API_KEY_ENABLE: | |||||
| if auth_header: | if auth_header: | ||||
| if " " not in auth_header: | if " " not in auth_header: | ||||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
| auth_scheme = auth_scheme.lower() | auth_scheme = auth_scheme.lower() | ||||
| if auth_scheme != "bearer": | if auth_scheme != "bearer": | ||||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
| admin_api_key = os.getenv("ADMIN_API_KEY") | |||||
| admin_api_key = dify_config.ADMIN_API_KEY | |||||
| if admin_api_key: | if admin_api_key: | ||||
| if os.getenv("ADMIN_API_KEY") == auth_token: | |||||
| if admin_api_key == auth_token: | |||||
| workspace_id = request.headers.get("X-WORKSPACE-ID") | workspace_id = request.headers.get("X-WORKSPACE-ID") | ||||
| if workspace_id: | if workspace_id: | ||||
| tenant_account_join = ( | tenant_account_join = ( | ||||
| account.current_tenant = tenant | account.current_tenant = tenant | ||||
| current_app.login_manager._update_request_context_with_user(account) | current_app.login_manager._update_request_context_with_user(account) | ||||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) | user_logged_in.send(current_app._get_current_object(), user=_get_user()) | ||||
| if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): | |||||
| if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | |||||
| pass | pass | ||||
| elif not current_user.is_authenticated: | elif not current_user.is_authenticated: | ||||
| return current_app.login_manager.unauthorized() | return current_app.login_manager.unauthorized() |
| import pytest | import pytest | ||||
| from app_factory import create_app | from app_factory import create_app | ||||
| from configs import dify_config | |||||
| mock_user = type( | mock_user = type( | ||||
| "MockUser", | "MockUser", | ||||
| @pytest.fixture | @pytest.fixture | ||||
| def app(): | def app(): | ||||
| app = create_app() | app = create_app() | ||||
| app.config["LOGIN_DISABLED"] = True | |||||
| dify_config.LOGIN_DISABLED = True | |||||
| return app | return app |
| return cls._instance | return cls._instance | ||||
| def __init__(self): | def __init__(self): | ||||
| self.storage = VolcengineTosStorage(app=Flask(__name__)) | |||||
| self.storage = VolcengineTosStorage() | |||||
| self.storage.bucket_name = get_example_bucket() | self.storage.bucket_name = get_example_bucket() | ||||
| self.storage.client = TosClientV2( | self.storage.client = TosClientV2( | ||||
| ak="dify", | ak="dify", |