| import time | import time | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from contexts.wrapper import RecyclableContextVar | |||||
| from dify_app import DifyApp | from dify_app import DifyApp | ||||
| dify_app = DifyApp(__name__) | dify_app = DifyApp(__name__) | ||||
| dify_app.config.from_mapping(dify_config.model_dump()) | dify_app.config.from_mapping(dify_config.model_dump()) | ||||
| # add before request hook | |||||
| @dify_app.before_request | |||||
| def before_request(): | |||||
| # add a unique identifier to each request | |||||
| RecyclableContextVar.increment_thread_recycles() | |||||
| return dify_app | return dify_app | ||||
| from threading import Lock | from threading import Lock | ||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| from contexts.wrapper import RecyclableContextVar | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | ||||
| from core.tools.plugin_tool.provider import PluginToolProviderController | from core.tools.plugin_tool.provider import PluginToolProviderController | ||||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | ||||
| plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") | |||||
| plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") | |||||
| """ | |||||
| To avoid race conditions caused by gunicorn thread recycling, RecyclableContextVar is used in place of a plain ContextVar for the variables below | |||||
| """ | |||||
| plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( | |||||
| ContextVar("plugin_tool_providers") | |||||
| ) | |||||
| plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) | |||||
| plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") | |||||
| plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") | |||||
| plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( | |||||
| ContextVar("plugin_model_providers") | |||||
| ) | |||||
| plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | |||||
| ContextVar("plugin_model_providers_lock") | |||||
| ) |
| from contextvars import ContextVar | |||||
| from typing import Generic, TypeVar | |||||
| T = TypeVar("T") | |||||
| class HiddenValue: | |||||
| pass | |||||
| _default = HiddenValue() | |||||
| class RecyclableContextVar(Generic[T]): | |||||
| """ | |||||
| RecyclableContextVar is a wrapper around ContextVar | |||||
| It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now | |||||
| NOTE: you need to call `increment_thread_recycles` before requests | |||||
| """ | |||||
| _thread_recycles: ContextVar[int] = ContextVar("thread_recycles") | |||||
| @classmethod | |||||
| def increment_thread_recycles(cls): | |||||
| try: | |||||
| recycles = cls._thread_recycles.get() | |||||
| cls._thread_recycles.set(recycles + 1) | |||||
| except LookupError: | |||||
| cls._thread_recycles.set(0) | |||||
| def __init__(self, context_var: ContextVar[T]): | |||||
| self._context_var = context_var | |||||
| self._updates = ContextVar[int](context_var.name + "_updates", default=0) | |||||
| def get(self, default: T | HiddenValue = _default) -> T: | |||||
| thread_recycles = self._thread_recycles.get(0) | |||||
| self_updates = self._updates.get() | |||||
| if thread_recycles > self_updates: | |||||
| self._updates.set(thread_recycles) | |||||
| # check if thread is recycled and should be updated | |||||
| if thread_recycles < self_updates: | |||||
| return self._context_var.get() | |||||
| else: | |||||
| # thread_recycles >= self_updates, means current context is invalid | |||||
| if isinstance(default, HiddenValue) or default is _default: | |||||
| raise LookupError | |||||
| else: | |||||
| return default | |||||
| def set(self, value: T): | |||||
| # it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before | |||||
| # increase it manually | |||||
| thread_recycles = self._thread_recycles.get(0) | |||||
| self_updates = self._updates.get() | |||||
| if thread_recycles > self_updates: | |||||
| self._updates.set(thread_recycles) | |||||
| if self._updates.get() == self._thread_recycles.get(0): | |||||
| # after increment, | |||||
| self._updates.set(self._updates.get() + 1) | |||||
| # set the context | |||||
| self._context_var.set(value) |
| from enum import StrEnum | from enum import StrEnum | ||||
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| from pydantic import BaseModel | |||||
| from pydantic import BaseModel, Field | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | ||||
| provider_type: ToolProviderType | provider_type: ToolProviderType | ||||
| provider_id: str | provider_id: str | ||||
| tool_name: str | tool_name: str | ||||
| tool_parameters: dict[str, Any] = {} | |||||
| tool_parameters: dict[str, Any] = Field(default_factory=dict) | |||||
| plugin_unique_identifier: str | None = None | plugin_unique_identifier: str | None = None | ||||
| from typing import Any | from typing import Any | ||||
| from core.app.app_config.entities import ModelConfigEntity | from core.app.app_config.entities import ModelConfigEntity | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | ||||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | ||||
| from core.plugin.entities.plugin import ModelProviderID | |||||
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | ||||
| if "/" not in config["model"]["provider"]: | if "/" not in config["model"]["provider"]: | ||||
| config["model"]["provider"] = ( | |||||
| f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" | |||||
| ) | |||||
| config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"])) | |||||
| if config["model"]["provider"] not in model_provider_names: | if config["model"]["provider"] not in model_provider_names: | ||||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") |
| provider: str | provider: str | ||||
| model: str | model: str | ||||
| mode: Optional[str] = None | mode: Optional[str] = None | ||||
| parameters: dict[str, Any] = {} | |||||
| stop: list[str] = [] | |||||
| parameters: dict[str, Any] = Field(default_factory=dict) | |||||
| stop: list[str] = Field(default_factory=list) | |||||
| class AdvancedChatMessageEntity(BaseModel): | class AdvancedChatMessageEntity(BaseModel): | ||||
| variable: str | variable: str | ||||
| type: str | type: str | ||||
| config: dict[str, Any] = {} | |||||
| config: dict[str, Any] = Field(default_factory=dict) | |||||
| class DatasetRetrieveConfigEntity(BaseModel): | class DatasetRetrieveConfigEntity(BaseModel): | ||||
| """ | """ | ||||
| type: str | type: str | ||||
| config: dict[str, Any] = {} | |||||
| config: dict[str, Any] = Field(default_factory=dict) | |||||
| class TextToSpeechEntity(BaseModel): | class TextToSpeechEntity(BaseModel): |
| model_schema: AIModelEntity | model_schema: AIModelEntity | ||||
| mode: str | mode: str | ||||
| provider_model_bundle: ProviderModelBundle | provider_model_bundle: ProviderModelBundle | ||||
| credentials: dict[str, Any] = {} | |||||
| parameters: dict[str, Any] = {} | |||||
| stop: list[str] = [] | |||||
| credentials: dict[str, Any] = Field(default_factory=dict) | |||||
| parameters: dict[str, Any] = Field(default_factory=dict) | |||||
| stop: list[str] = Field(default_factory=list) | |||||
| # pydantic configs | # pydantic configs | ||||
| model_config = ConfigDict(protected_namespaces=()) | model_config = ConfigDict(protected_namespaces=()) | ||||
| call_depth: int = 0 | call_depth: int = 0 | ||||
| # extra parameters, like: auto_generate_conversation_name | # extra parameters, like: auto_generate_conversation_name | ||||
| extras: dict[str, Any] = {} | |||||
| extras: dict[str, Any] = Field(default_factory=dict) | |||||
| # tracing instance | # tracing instance | ||||
| trace_manager: Optional[TraceQueueManager] = None | trace_manager: Optional[TraceQueueManager] = None |
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Optional | from typing import Optional | ||||
| from pydantic import BaseModel, ConfigDict | |||||
| from pydantic import BaseModel, ConfigDict, Field | |||||
| from sqlalchemy import or_ | from sqlalchemy import or_ | ||||
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | ||||
| from core.entities.provider_entities import ( | from core.entities.provider_entities import ( | ||||
| CustomConfiguration, | CustomConfiguration, | ||||
| """ | """ | ||||
| tenant_id: str | tenant_id: str | ||||
| configurations: dict[str, ProviderConfiguration] = {} | |||||
| configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict) | |||||
| def __init__(self, tenant_id: str): | def __init__(self, tenant_id: str): | ||||
| super().__init__(tenant_id=tenant_id) | super().__init__(tenant_id=tenant_id) | ||||
| def __getitem__(self, key): | def __getitem__(self, key): | ||||
| if "/" not in key: | if "/" not in key: | ||||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||||
| key = str(ModelProviderID(key)) | |||||
| return self.configurations[key] | return self.configurations[key] | ||||
| def get(self, key, default=None) -> ProviderConfiguration | None: | def get(self, key, default=None) -> ProviderConfiguration | None: | ||||
| if "/" not in key: | if "/" not in key: | ||||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||||
| key = str(ModelProviderID(key)) | |||||
| return self.configurations.get(key, default) # type: ignore | return self.configurations.get(key, default) # type: ignore | ||||
| class HostingConfiguration: | class HostingConfiguration: | ||||
| provider_map: dict[str, HostingProvider] = {} | |||||
| provider_map: dict[str, HostingProvider] | |||||
| moderation_config: Optional[HostedModerationConfig] = None | moderation_config: Optional[HostedModerationConfig] = None | ||||
| def __init__(self) -> None: | |||||
| self.provider_map = {} | |||||
| self.moderation_config = None | |||||
| def init_app(self, app: Flask) -> None: | def init_app(self, app: Flask) -> None: | ||||
| if dify_config.EDITION != "CLOUD": | if dify_config.EDITION != "CLOUD": | ||||
| return | return |
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| import contexts | import contexts | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map | from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map | ||||
| from core.model_runtime.entities.model_entities import AIModelEntity, ModelType | from core.model_runtime.entities.model_entities import AIModelEntity, ModelType | ||||
| from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | ||||
| class ModelProviderFactory: | class ModelProviderFactory: | ||||
| provider_position_map: dict[str, int] = {} | |||||
| provider_position_map: dict[str, int] | |||||
| def __init__(self, tenant_id: str) -> None: | def __init__(self, tenant_id: str) -> None: | ||||
| self.provider_position_map = {} | |||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.plugin_model_manager = PluginModelManager() | self.plugin_model_manager = PluginModelManager() | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :return: plugin id and provider name | :return: plugin id and provider name | ||||
| """ | """ | ||||
| plugin_id = DEFAULT_PLUGIN_ID | |||||
| provider_name = provider | |||||
| if "/" in provider: | |||||
| # get the plugin_id before provider | |||||
| plugin_id = "/".join(provider.split("/")[:-1]) | |||||
| provider_name = provider.split("/")[-1] | |||||
| return str(plugin_id), provider_name | |||||
| provider_id = ModelProviderID(provider) | |||||
| return provider_id.plugin_id, provider_id.provider_name |
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.plugin.entities.plugin import ModelProviderID | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from events.dataset_event import dataset_was_deleted | from events.dataset_event import dataset_was_deleted | ||||
| else: | else: | ||||
| # add default plugin id to both setting sets, to make sure the plugin model provider is consistent | # add default plugin id to both setting sets, to make sure the plugin model provider is consistent | ||||
| plugin_model_provider = dataset.embedding_model_provider | plugin_model_provider = dataset.embedding_model_provider | ||||
| if "/" not in plugin_model_provider: | |||||
| plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}" | |||||
| plugin_model_provider = str(ModelProviderID(plugin_model_provider)) | |||||
| new_plugin_model_provider = data["embedding_model_provider"] | new_plugin_model_provider = data["embedding_model_provider"] | ||||
| if "/" not in new_plugin_model_provider: | |||||
| new_plugin_model_provider = ( | |||||
| f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}" | |||||
| ) | |||||
| new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider)) | |||||
| if ( | if ( | ||||
| new_plugin_model_provider != plugin_model_provider | new_plugin_model_provider != plugin_model_provider |