Browse Source

refactor: use dify_config to replace legacy usage of flask app's config (#9089)

tags/0.10.1
Bowen Liang 1 year ago
parent
commit
4d9160ca9f
No account linked to committer's email address

+ 6
- 10
api/app.py View File

import os import os


from configs import dify_config

if os.environ.get("DEBUG", "false").lower() != "true": if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey from gevent import monkey


time.tzset() time.tzset()




# -------------
# Configuration
# -------------
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first


# create app # create app
app = create_app() app = create_app()
celery = app.extensions["celery"] celery = app.extensions["celery"]


if app.config.get("TESTING"):
if dify_config.TESTING:
print("App is running in TESTING mode") print("App is running in TESTING mode")




def after_request(response): def after_request(response):
"""Add Version headers to the response.""" """Add Version headers to the response."""
response.set_cookie("remember_token", "", expires=0) response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response return response




@app.route("/health") @app.route("/health")
def health(): def health():
return Response( return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
status=200, status=200,
content_type="application/json", content_type="application/json",
) )

+ 3
- 3
api/app_factory.py View File



def create_app() -> Flask: def create_app() -> Flask:
app = create_flask_app_with_configs() app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
app.secret_key = dify_config.SECRET_KEY
initialize_extensions(app) initialize_extensions(app)
register_blueprints(app) register_blueprints(app)
register_commands(app) register_commands(app)


CORS( CORS(
web_bp, web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True, supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"], allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],


CORS( CORS(
console_app_bp, console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True, supports_credentials=True,
allow_headers=["Content-Type", "Authorization"], allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],

+ 15
- 0
api/configs/feature/__init__.py View File

default=5, default=5,
) )


LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
)

ADMIN_API_KEY_ENABLE: bool = Field(
description="Whether to enable admin api key for authentication",
default=False,
)

ADMIN_API_KEY: Optional[str] = Field(
description="admin api key for authentication",
default=None,
)



class AppExecutionConfig(BaseSettings): class AppExecutionConfig(BaseSettings):
""" """

+ 3
- 3
api/controllers/console/admin.py View File

import os
from functools import wraps from functools import wraps


from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized


from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
def admin_required(view): def admin_required(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not os.getenv("ADMIN_API_KEY"):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")


auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")


if os.getenv("ADMIN_API_KEY") != auth_token:
if dify_config.ADMIN_API_KEY != auth_token:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")


return view(*args, **kwargs) return view(*args, **kwargs)

+ 44
- 47
api/core/hosting_configuration.py View File

from typing import Optional from typing import Optional


from flask import Config, Flask
from flask import Flask
from pydantic import BaseModel from pydantic import BaseModel


from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType from models.provider import ProviderQuotaType
moderation_config: HostedModerationConfig = None moderation_config: HostedModerationConfig = None


def init_app(self, app: Flask) -> None: def init_app(self, app: Flask) -> None:
config = app.config

if config.get("EDITION") != "CLOUD":
if dify_config.EDITION != "CLOUD":
return return


self.provider_map["azure_openai"] = self.init_azure_openai(config)
self.provider_map["openai"] = self.init_openai(config)
self.provider_map["anthropic"] = self.init_anthropic(config)
self.provider_map["minimax"] = self.init_minimax(config)
self.provider_map["spark"] = self.init_spark(config)
self.provider_map["zhipuai"] = self.init_zhipuai(config)
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()


self.moderation_config = self.init_moderation_config(config)
self.moderation_config = self.init_moderation_config()


@staticmethod @staticmethod
def init_azure_openai(app_config: Config) -> HostingProvider:
def init_azure_openai() -> HostingProvider:
quota_unit = QuotaUnit.TIMES quota_unit = QuotaUnit.TIMES
if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
credentials = { credentials = {
"openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
"openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
"openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
"base_model_name": "gpt-35-turbo", "base_model_name": "gpt-35-turbo",
} }


quotas = [] quotas = []
hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota( trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit, quota_limit=hosted_quota_limit,
restrict_models=[ restrict_models=[
quota_unit=quota_unit, quota_unit=quota_unit,
) )


def init_openai(self, app_config: Config) -> HostingProvider:
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS quota_unit = QuotaUnit.CREDITS
quotas = [] quotas = []


if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota) quotas.append(trial_quota)


if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
if dify_config.HOSTED_OPENAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models) paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota) quotas.append(paid_quota)


if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
"openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
"openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
} }


if app_config.get("HOSTED_OPENAI_API_BASE"):
credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
if dify_config.HOSTED_OPENAI_API_BASE:
credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE


if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION


return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)


) )


@staticmethod @staticmethod
def init_anthropic(app_config: Config) -> HostingProvider:
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS quota_unit = QuotaUnit.TOKENS
quotas = [] quotas = []


if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
quotas.append(trial_quota) quotas.append(trial_quota)


if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota() paid_quota = PaidHostingQuota()
quotas.append(paid_quota) quotas.append(paid_quota)


if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
"anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
"anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
} }


if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
if dify_config.HOSTED_ANTHROPIC_API_BASE:
credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE


return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)


) )


@staticmethod @staticmethod
def init_minimax(app_config: Config) -> HostingProvider:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_MINIMAX_ENABLED"):
if dify_config.HOSTED_MINIMAX_ENABLED:
quotas = [FreeHostingQuota()] quotas = [FreeHostingQuota()]


return HostingProvider( return HostingProvider(
) )


@staticmethod @staticmethod
def init_spark(app_config: Config) -> HostingProvider:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_SPARK_ENABLED"):
if dify_config.HOSTED_SPARK_ENABLED:
quotas = [FreeHostingQuota()] quotas = [FreeHostingQuota()]


return HostingProvider( return HostingProvider(
) )


@staticmethod @staticmethod
def init_zhipuai(app_config: Config) -> HostingProvider:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
if dify_config.HOSTED_ZHIPUAI_ENABLED:
quotas = [FreeHostingQuota()] quotas = [FreeHostingQuota()]


return HostingProvider( return HostingProvider(
) )


@staticmethod @staticmethod
def init_moderation_config(app_config: Config) -> HostedModerationConfig:
if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"):
return HostedModerationConfig(
enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",")
)
def init_moderation_config() -> HostedModerationConfig:
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))


return HostedModerationConfig(enabled=False) return HostedModerationConfig(enabled=False)


@staticmethod @staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
models_str = dify_config.model_dump().get(env_var)
models_list = models_str.split(",") if models_str else [] models_list = models_str.split(",") if models_str else []
return [ return [
RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)

+ 1
- 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py View File

if not dataset.index_struct_dict: if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))


config = current_app.config
return QdrantVector( return QdrantVector(
collection_name=collection_name, collection_name=collection_name,
group_id=dataset.id, group_id=dataset.id,
config=QdrantConfig( config=QdrantConfig(
endpoint=dify_config.QDRANT_URL, endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY, api_key=dify_config.QDRANT_API_KEY,
root_path=config.root_path,
root_path=current_app.config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT, timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT, grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,

+ 10
- 8
api/extensions/ext_celery.py View File

from celery import Celery, Task from celery import Celery, Task
from flask import Flask from flask import Flask


from configs import dify_config



def init_app(app: Flask) -> Celery: def init_app(app: Flask) -> Celery:
class FlaskTask(Task): class FlaskTask(Task):


broker_transport_options = {} broker_transport_options = {}


if app.config.get("CELERY_USE_SENTINEL"):
if dify_config.CELERY_USE_SENTINEL:
broker_transport_options = { broker_transport_options = {
"master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"),
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": { "sentinel_kwargs": {
"socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1),
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
}, },
} }


celery_app = Celery( celery_app = Celery(
app.name, app.name,
task_cls=FlaskTask, task_cls=FlaskTask,
broker=app.config.get("CELERY_BROKER_URL"),
backend=app.config.get("CELERY_BACKEND"),
broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND,
task_ignore_result=True, task_ignore_result=True,
) )


} }


celery_app.conf.update( celery_app.conf.update(
result_backend=app.config.get("CELERY_RESULT_BACKEND"),
result_backend=dify_config.CELERY_RESULT_BACKEND,
broker_transport_options=broker_transport_options, broker_transport_options=broker_transport_options,
broker_connection_retry_on_startup=True, broker_connection_retry_on_startup=True,
) )


if app.config.get("BROKER_USE_SSL"):
if dify_config.BROKER_USE_SSL:
celery_app.conf.update( celery_app.conf.update(
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
) )
"schedule.clean_embedding_cache_task", "schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task", "schedule.clean_unused_datasets_task",
] ]
day = app.config.get("CELERY_BEAT_SCHEDULER_TIME")
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = { beat_schedule = {
"clean_embedding_cache_task": { "clean_embedding_cache_task": {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",

+ 3
- 1
api/extensions/ext_compress.py View File

from flask import Flask from flask import Flask


from configs import dify_config



def init_app(app: Flask): def init_app(app: Flask):
if app.config.get("API_COMPRESSION_ENABLED"):
if dify_config.API_COMPRESSION_ENABLED:
from flask_compress import Compress from flask_compress import Compress


app.config["COMPRESS_MIMETYPES"] = [ app.config["COMPRESS_MIMETYPES"] = [

+ 7
- 5
api/extensions/ext_logging.py View File



from flask import Flask from flask import Flask


from configs import dify_config



def init_app(app: Flask): def init_app(app: Flask):
log_handlers = None log_handlers = None
log_file = app.config.get("LOG_FILE")
log_file = dify_config.LOG_FILE
if log_file: if log_file:
log_dir = os.path.dirname(log_file) log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
] ]


logging.basicConfig( logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers, handlers=log_handlers,
force=True, force=True,
) )
log_tz = app.config.get("LOG_TZ")
log_tz = dify_config.LOG_TZ
if log_tz: if log_tz:
from datetime import datetime from datetime import datetime



+ 25
- 20
api/extensions/ext_mail.py View File

import resend import resend
from flask import Flask from flask import Flask


from configs import dify_config



class Mail: class Mail:
def __init__(self): def __init__(self):
return self._client is not None return self._client is not None


def init_app(self, app: Flask): def init_app(self, app: Flask):
if app.config.get("MAIL_TYPE"):
if app.config.get("MAIL_DEFAULT_SEND_FROM"):
self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
mail_type = dify_config.MAIL_TYPE
if not mail_type:
logging.warning("MAIL_TYPE is not set")
return


if app.config.get("MAIL_TYPE") == "resend":
api_key = app.config.get("RESEND_API_KEY")
if dify_config.MAIL_DEFAULT_SEND_FROM:
self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM

match mail_type:
case "resend":
api_key = dify_config.RESEND_API_KEY
if not api_key: if not api_key:
raise ValueError("RESEND_API_KEY is not set") raise ValueError("RESEND_API_KEY is not set")


api_url = app.config.get("RESEND_API_URL")
api_url = dify_config.RESEND_API_URL
if api_url: if api_url:
resend.api_url = api_url resend.api_url = api_url


resend.api_key = api_key resend.api_key = api_key
self._client = resend.Emails self._client = resend.Emails
elif app.config.get("MAIL_TYPE") == "smtp":
case "smtp":
from libs.smtp import SMTPClient from libs.smtp import SMTPClient


if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
self._client = SMTPClient( self._client = SMTPClient(
server=app.config.get("SMTP_SERVER"),
port=app.config.get("SMTP_PORT"),
username=app.config.get("SMTP_USERNAME"),
password=app.config.get("SMTP_PASSWORD"),
_from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
use_tls=app.config.get("SMTP_USE_TLS"),
opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
server=dify_config.SMTP_SERVER,
port=dify_config.SMTP_PORT,
username=dify_config.SMTP_USERNAME,
password=dify_config.SMTP_PASSWORD,
_from=dify_config.MAIL_DEFAULT_SEND_FROM,
use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
) )
else:
raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
else:
logging.warning("MAIL_TYPE is not set")
case _:
raise ValueError("Unsupported mail type {}".format(mail_type))


def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client: if not self._client:

+ 14
- 12
api/extensions/ext_redis.py View File

from redis.connection import Connection, SSLConnection from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel from redis.sentinel import Sentinel


from configs import dify_config



class RedisClientWrapper(redis.Redis): class RedisClientWrapper(redis.Redis):
""" """
def init_app(app): def init_app(app):
global redis_client global redis_client
connection_class = Connection connection_class = Connection
if app.config.get("REDIS_USE_SSL"):
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection connection_class = SSLConnection


redis_params = { redis_params = {
"username": app.config.get("REDIS_USERNAME"),
"password": app.config.get("REDIS_PASSWORD"),
"db": app.config.get("REDIS_DB"),
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD,
"db": dify_config.REDIS_DB,
"encoding": "utf-8", "encoding": "utf-8",
"encoding_errors": "strict", "encoding_errors": "strict",
"decode_responses": False, "decode_responses": False,
} }


if app.config.get("REDIS_USE_SENTINEL"):
if dify_config.REDIS_USE_SENTINEL:
sentinel_hosts = [ sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",")
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
] ]
sentinel = Sentinel( sentinel = Sentinel(
sentinel_hosts, sentinel_hosts,
sentinel_kwargs={ sentinel_kwargs={
"socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1),
"username": app.config.get("REDIS_SENTINEL_USERNAME"),
"password": app.config.get("REDIS_SENTINEL_PASSWORD"),
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
}, },
) )
master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params)
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master) redis_client.initialize(master)
else: else:
redis_params.update( redis_params.update(
{ {
"host": app.config.get("REDIS_HOST"),
"port": app.config.get("REDIS_PORT"),
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class, "connection_class": connection_class,
} }
) )

+ 7
- 6
api/extensions/ext_sentry.py View File

from sentry_sdk.integrations.flask import FlaskIntegration from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException


from configs import dify_config
from core.model_runtime.errors.invoke import InvokeRateLimitError from core.model_runtime.errors.invoke import InvokeRateLimitError








def init_app(app): def init_app(app):
if app.config.get("SENTRY_DSN"):
if dify_config.SENTRY_DSN:
sentry_sdk.init( sentry_sdk.init(
dsn=app.config.get("SENTRY_DSN"),
dsn=dify_config.SENTRY_DSN,
integrations=[FlaskIntegration(), CeleryIntegration()], integrations=[FlaskIntegration(), CeleryIntegration()],
ignore_errors=[ ignore_errors=[
HTTPException, HTTPException,
InvokeRateLimitError, InvokeRateLimitError,
parse_error.defaultErrorResponse, parse_error.defaultErrorResponse,
], ],
traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
environment=app.config.get("DEPLOY_ENV"),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
environment=dify_config.DEPLOY_ENV,
release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
before_send=before_send, before_send=before_send,
) )

+ 1
- 1
api/extensions/ext_storage.py View File



def init_app(self, app: Flask): def init_app(self, app: Flask):
storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
self.storage_runner = storage_factory(app=app)
self.storage_runner = storage_factory()


@staticmethod @staticmethod
def get_storage_factory(storage_type: str) -> type[BaseStorage]: def get_storage_factory(storage_type: str) -> type[BaseStorage]:

+ 9
- 11
api/extensions/storage/aliyun_oss_storage.py View File

from collections.abc import Generator from collections.abc import Generator


import oss2 as aliyun_s3 import oss2 as aliyun_s3
from flask import Flask


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class AliyunOssStorage(BaseStorage): class AliyunOssStorage(BaseStorage):
"""Implementation for Aliyun OSS storage.""" """Implementation for Aliyun OSS storage."""


def __init__(self, app: Flask):
super().__init__(app)

app_config = self.app.config
self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME")
self.folder = app.config.get("ALIYUN_OSS_PATH")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
self.folder = dify_config.ALIYUN_OSS_PATH
oss_auth_method = aliyun_s3.Auth oss_auth_method = aliyun_s3.Auth
region = None region = None
if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4":
if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
oss_auth_method = aliyun_s3.AuthV4 oss_auth_method = aliyun_s3.AuthV4
region = app_config.get("ALIYUN_OSS_REGION")
oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY"))
region = dify_config.ALIYUN_OSS_REGION
oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
self.client = aliyun_s3.Bucket( self.client = aliyun_s3.Bucket(
oss_auth, oss_auth,
app_config.get("ALIYUN_OSS_ENDPOINT"),
dify_config.ALIYUN_OSS_ENDPOINT,
self.bucket_name, self.bucket_name,
connect_timeout=30, connect_timeout=30,
region=region, region=region,

+ 7
- 8
api/extensions/storage/azure_blob_storage.py View File

from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone


from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from flask import Flask


from configs import dify_config
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage


class AzureBlobStorage(BaseStorage): class AzureBlobStorage(BaseStorage):
"""Implementation for Azure Blob storage.""" """Implementation for Azure Blob storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME")
self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL")
self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME")
self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY


def save(self, filename, data): def save(self, filename, data):
client = self._sync_client() client = self._sync_client()

+ 7
- 8
api/extensions/storage/baidu_obs_storage.py View File

from baidubce.auth.bce_credentials import BceCredentials from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient from baidubce.services.bos.bos_client import BosClient
from flask import Flask


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class BaiduObsStorage(BaseStorage): class BaiduObsStorage(BaseStorage):
"""Implementation for Baidu OBS storage.""" """Implementation for Baidu OBS storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("BAIDU_OBS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
client_config = BceClientConfiguration( client_config = BceClientConfiguration(
credentials=BceCredentials( credentials=BceCredentials(
access_key_id=app_config.get("BAIDU_OBS_ACCESS_KEY"),
secret_access_key=app_config.get("BAIDU_OBS_SECRET_KEY"),
access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
), ),
endpoint=app_config.get("BAIDU_OBS_ENDPOINT"),
endpoint=dify_config.BAIDU_OBS_ENDPOINT,
) )


self.client = BosClient(config=client_config) self.client = BosClient(config=client_config)

+ 2
- 6
api/extensions/storage/base_storage.py View File

from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator


from flask import Flask



class BaseStorage(ABC): class BaseStorage(ABC):
"""Interface for file storage.""" """Interface for file storage."""


app = None

def __init__(self, app: Flask):
self.app = app
def __init__(self): # noqa: B027
pass


@abstractmethod @abstractmethod
def save(self, filename, data): def save(self, filename, data):

+ 6
- 6
api/extensions/storage/google_cloud_storage.py View File

import json import json
from collections.abc import Generator from collections.abc import Generator


from flask import Flask
from google.cloud import storage as google_cloud_storage from google.cloud import storage as google_cloud_storage


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class GoogleCloudStorage(BaseStorage): class GoogleCloudStorage(BaseStorage):
"""Implementation for Google Cloud storage.""" """Implementation for Google Cloud storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME")
service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
# if service_account_json_str is empty, use Application Default Credentials # if service_account_json_str is empty, use Application Default Credentials
if service_account_json_str: if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")

+ 8
- 8
api/extensions/storage/huawei_obs_storage.py View File

from collections.abc import Generator from collections.abc import Generator


from flask import Flask
from obs import ObsClient from obs import ObsClient


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class HuaweiObsStorage(BaseStorage): class HuaweiObsStorage(BaseStorage):
"""Implementation for Huawei OBS storage.""" """Implementation for Huawei OBS storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
self.client = ObsClient( self.client = ObsClient(
access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"),
secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"),
server=app_config.get("HUAWEI_OBS_SERVER"),
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
server=dify_config.HUAWEI_OBS_SERVER,
) )


def save(self, filename, data): def save(self, filename, data):

+ 6
- 5
api/extensions/storage/local_fs_storage.py View File

from collections.abc import Generator from collections.abc import Generator
from pathlib import Path from pathlib import Path


from flask import Flask
from flask import current_app


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class LocalFsStorage(BaseStorage): class LocalFsStorage(BaseStorage):
"""Implementation for local filesystem storage.""" """Implementation for local filesystem storage."""


def __init__(self, app: Flask):
super().__init__(app)
folder = self.app.config.get("STORAGE_LOCAL_PATH")
def __init__(self):
super().__init__()
folder = dify_config.STORAGE_LOCAL_PATH
if not os.path.isabs(folder): if not os.path.isabs(folder):
folder = os.path.join(app.root_path, folder)
folder = os.path.join(current_app.root_path, folder)
self.folder = folder self.folder = folder


def save(self, filename, data): def save(self, filename, data):

+ 9
- 9
api/extensions/storage/oracle_oci_storage.py View File



import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from flask import Flask


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class OracleOCIStorage(BaseStorage): class OracleOCIStorage(BaseStorage):
"""Implementation for Oracle OCI storage.""" """Implementation for Oracle OCI storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("OCI_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.OCI_BUCKET_NAME
self.client = boto3.client( self.client = boto3.client(
"s3", "s3",
aws_secret_access_key=app_config.get("OCI_SECRET_KEY"),
aws_access_key_id=app_config.get("OCI_ACCESS_KEY"),
endpoint_url=app_config.get("OCI_ENDPOINT"),
region_name=app_config.get("OCI_REGION"),
aws_secret_access_key=dify_config.OCI_SECRET_KEY,
aws_access_key_id=dify_config.OCI_ACCESS_KEY,
endpoint_url=dify_config.OCI_ENDPOINT,
region_name=dify_config.OCI_REGION,
) )


def save(self, filename, data): def save(self, filename, data):

+ 9
- 9
api/extensions/storage/tencent_cos_storage.py View File

from collections.abc import Generator from collections.abc import Generator


from flask import Flask
from qcloud_cos import CosConfig, CosS3Client from qcloud_cos import CosConfig, CosS3Client


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class TencentCosStorage(BaseStorage): class TencentCosStorage(BaseStorage):
"""Implementation for Tencent Cloud COS storage.""" """Implementation for Tencent Cloud COS storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
config = CosConfig( config = CosConfig(
Region=app_config.get("TENCENT_COS_REGION"),
SecretId=app_config.get("TENCENT_COS_SECRET_ID"),
SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"),
Scheme=app_config.get("TENCENT_COS_SCHEME"),
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
) )
self.client = CosS3Client(config) self.client = CosS3Client(config)



+ 8
- 9
api/extensions/storage/volcengine_tos_storage.py View File

from collections.abc import Generator from collections.abc import Generator


import tos import tos
from flask import Flask


from configs import dify_config
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage




class VolcengineTosStorage(BaseStorage): class VolcengineTosStorage(BaseStorage):
"""Implementation for Volcengine TOS storage.""" """Implementation for Volcengine TOS storage."""


def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
self.client = tos.TosClientV2( self.client = tos.TosClientV2(
ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"),
sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"),
endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"),
region=app_config.get("VOLCENGINE_TOS_REGION"),
ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,
sk=dify_config.VOLCENGINE_TOS_SECRET_KEY,
endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT,
region=dify_config.VOLCENGINE_TOS_REGION,
) )


def save(self, filename, data): def save(self, filename, data):

+ 3
- 2
api/libs/helper.py View File

from typing import Any, Optional, Union from typing import Any, Optional, Union
from zoneinfo import available_timezones from zoneinfo import available_timezones


from flask import Response, current_app, stream_with_context
from flask import Response, stream_with_context
from flask_restful import fields from flask_restful import fields


from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
if additional_data: if additional_data:
token_data.update(additional_data) token_data.update(additional_data)


expiry_minutes = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"]
expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
token_key = cls._get_token_key(token, token_type) token_key = cls._get_token_key(token, token_type)
expiry_time = int(expiry_minutes * 60) expiry_time = int(expiry_minutes * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data)) redis_client.setex(token_key, expiry_time, json.dumps(token_data))

+ 5
- 6
api/libs/login.py View File

import os
from functools import wraps from functools import wraps


from flask import current_app, g, has_request_context, request from flask import current_app, g, has_request_context, request
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy


from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin from models.account import Account, Tenant, TenantAccountJoin


@wraps(func) @wraps(func)
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False")
if admin_api_key_enable.lower() == "true":
if dify_config.ADMIN_API_KEY_ENABLE:
if auth_header: if auth_header:
if " " not in auth_header: if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme = auth_scheme.lower() auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
admin_api_key = os.getenv("ADMIN_API_KEY")


admin_api_key = dify_config.ADMIN_API_KEY
if admin_api_key: if admin_api_key:
if os.getenv("ADMIN_API_KEY") == auth_token:
if admin_api_key == auth_token:
workspace_id = request.headers.get("X-WORKSPACE-ID") workspace_id = request.headers.get("X-WORKSPACE-ID")
if workspace_id: if workspace_id:
tenant_account_join = ( tenant_account_join = (
account.current_tenant = tenant account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user()) user_logged_in.send(current_app._get_current_object(), user=_get_user())
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass pass
elif not current_user.is_authenticated: elif not current_user.is_authenticated:
return current_app.login_manager.unauthorized() return current_app.login_manager.unauthorized()

+ 2
- 1
api/tests/integration_tests/controllers/app_fixture.py View File

import pytest import pytest


from app_factory import create_app from app_factory import create_app
from configs import dify_config


mock_user = type( mock_user = type(
"MockUser", "MockUser",
@pytest.fixture @pytest.fixture
def app(): def app():
app = create_app() app = create_app()
app.config["LOGIN_DISABLED"] = True
dify_config.LOGIN_DISABLED = True
return app return app

+ 1
- 1
api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py View File

return cls._instance return cls._instance


def __init__(self): def __init__(self):
self.storage = VolcengineTosStorage(app=Flask(__name__))
self.storage = VolcengineTosStorage()
self.storage.bucket_name = get_example_bucket() self.storage.bucket_name = get_example_bucket()
self.storage.client = TosClientV2( self.storage.client = TosClientV2(
ak="dify", ak="dify",

Loading…
Cancel
Save