| @@ -2,6 +2,7 @@ import logging | |||
| import time | |||
| from configs import dify_config | |||
| from contexts.wrapper import RecyclableContextVar | |||
| from dify_app import DifyApp | |||
| @@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp: | |||
| dify_app = DifyApp(__name__) | |||
| dify_app.config.from_mapping(dify_config.model_dump()) | |||
| # add before request hook | |||
| @dify_app.before_request | |||
| def before_request(): | |||
| # add a unique identifier to each request | |||
| RecyclableContextVar.increment_thread_recycles() | |||
| return dify_app | |||
| @@ -2,6 +2,8 @@ from contextvars import ContextVar | |||
| from threading import Lock | |||
| from typing import TYPE_CHECKING | |||
| from contexts.wrapper import RecyclableContextVar | |||
| if TYPE_CHECKING: | |||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||
| from core.tools.plugin_tool.provider import PluginToolProviderController | |||
| @@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id") | |||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | |||
| plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") | |||
| plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") | |||
| """ | |||
| To avoid race conditions caused by gunicorn thread recycling, use RecyclableContextVar in place of a plain ContextVar | |||
| """ | |||
| plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( | |||
| ContextVar("plugin_tool_providers") | |||
| ) | |||
| plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) | |||
| plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") | |||
| plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") | |||
| plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( | |||
| ContextVar("plugin_model_providers") | |||
| ) | |||
| plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | |||
| ContextVar("plugin_model_providers_lock") | |||
| ) | |||
| @@ -0,0 +1,65 @@ | |||
| from contextvars import ContextVar | |||
| from typing import Generic, TypeVar | |||
| T = TypeVar("T") | |||
| class HiddenValue: | |||
| pass | |||
| _default = HiddenValue() | |||
| class RecyclableContextVar(Generic[T]): | |||
| """ | |||
| RecyclableContextVar is a wrapper around ContextVar | |||
| It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now | |||
| NOTE: you need to call `increment_thread_recycles` before requests | |||
| """ | |||
| _thread_recycles: ContextVar[int] = ContextVar("thread_recycles") | |||
| @classmethod | |||
| def increment_thread_recycles(cls): | |||
| try: | |||
| recycles = cls._thread_recycles.get() | |||
| cls._thread_recycles.set(recycles + 1) | |||
| except LookupError: | |||
| cls._thread_recycles.set(0) | |||
| def __init__(self, context_var: ContextVar[T]): | |||
| self._context_var = context_var | |||
| self._updates = ContextVar[int](context_var.name + "_updates", default=0) | |||
| def get(self, default: T | HiddenValue = _default) -> T: | |||
| thread_recycles = self._thread_recycles.get(0) | |||
| self_updates = self._updates.get() | |||
| if thread_recycles > self_updates: | |||
| self._updates.set(thread_recycles) | |||
| # check if thread is recycled and should be updated | |||
| if thread_recycles < self_updates: | |||
| return self._context_var.get() | |||
| else: | |||
| # thread_recycles >= self_updates, means current context is invalid | |||
| if isinstance(default, HiddenValue) or default is _default: | |||
| raise LookupError | |||
| else: | |||
| return default | |||
| def set(self, value: T): | |||
| # it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before | |||
| # increase it manually | |||
| thread_recycles = self._thread_recycles.get(0) | |||
| self_updates = self._updates.get() | |||
| if thread_recycles > self_updates: | |||
| self._updates.set(thread_recycles) | |||
| if self._updates.get() == self._thread_recycles.get(0): | |||
| # after increment, | |||
| self._updates.set(self._updates.get() + 1) | |||
| # set the context | |||
| self._context_var.set(value) | |||
| @@ -1,7 +1,7 @@ | |||
| from enum import StrEnum | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, Field | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||
| @@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel): | |||
| provider_type: ToolProviderType | |||
| provider_id: str | |||
| tool_name: str | |||
| tool_parameters: dict[str, Any] = {} | |||
| tool_parameters: dict[str, Any] = Field(default_factory=dict) | |||
| plugin_unique_identifier: str | None = None | |||
| @@ -2,9 +2,9 @@ from collections.abc import Mapping | |||
| from typing import Any | |||
| from core.app.app_config.entities import ModelConfigEntity | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| @@ -61,9 +61,7 @@ class ModelConfigManager: | |||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | |||
| if "/" not in config["model"]["provider"]: | |||
| config["model"]["provider"] = ( | |||
| f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" | |||
| ) | |||
| config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"])) | |||
| if config["model"]["provider"] not in model_provider_names: | |||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | |||
| @@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel): | |||
| provider: str | |||
| model: str | |||
| mode: Optional[str] = None | |||
| parameters: dict[str, Any] = {} | |||
| stop: list[str] = [] | |||
| parameters: dict[str, Any] = Field(default_factory=dict) | |||
| stop: list[str] = Field(default_factory=list) | |||
| class AdvancedChatMessageEntity(BaseModel): | |||
| @@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel): | |||
| variable: str | |||
| type: str | |||
| config: dict[str, Any] = {} | |||
| config: dict[str, Any] = Field(default_factory=dict) | |||
| class DatasetRetrieveConfigEntity(BaseModel): | |||
| @@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): | |||
| """ | |||
| type: str | |||
| config: dict[str, Any] = {} | |||
| config: dict[str, Any] = Field(default_factory=dict) | |||
| class TextToSpeechEntity(BaseModel): | |||
| @@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel): | |||
| model_schema: AIModelEntity | |||
| mode: str | |||
| provider_model_bundle: ProviderModelBundle | |||
| credentials: dict[str, Any] = {} | |||
| parameters: dict[str, Any] = {} | |||
| stop: list[str] = [] | |||
| credentials: dict[str, Any] = Field(default_factory=dict) | |||
| parameters: dict[str, Any] = Field(default_factory=dict) | |||
| stop: list[str] = Field(default_factory=list) | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel): | |||
| call_depth: int = 0 | |||
| # extra parameters, like: auto_generate_conversation_name | |||
| extras: dict[str, Any] = {} | |||
| extras: dict[str, Any] = Field(default_factory=dict) | |||
| # tracing instance | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| @@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence | |||
| from json import JSONDecodeError | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from sqlalchemy import or_ | |||
| from constants import HIDDEN_VALUE | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | |||
| from core.entities.provider_entities import ( | |||
| CustomConfiguration, | |||
| @@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel): | |||
| """ | |||
| tenant_id: str | |||
| configurations: dict[str, ProviderConfiguration] = {} | |||
| configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict) | |||
| def __init__(self, tenant_id: str): | |||
| super().__init__(tenant_id=tenant_id) | |||
| @@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel): | |||
| def __getitem__(self, key): | |||
| if "/" not in key: | |||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||
| key = str(ModelProviderID(key)) | |||
| return self.configurations[key] | |||
| @@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel): | |||
| def get(self, key, default=None) -> ProviderConfiguration | None: | |||
| if "/" not in key: | |||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||
| key = str(ModelProviderID(key)) | |||
| return self.configurations.get(key, default) # type: ignore | |||
| @@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel): | |||
| class HostingConfiguration: | |||
| provider_map: dict[str, HostingProvider] = {} | |||
| provider_map: dict[str, HostingProvider] | |||
| moderation_config: Optional[HostedModerationConfig] = None | |||
| def __init__(self) -> None: | |||
| self.provider_map = {} | |||
| self.moderation_config = None | |||
| def init_app(self, app: Flask) -> None: | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| @@ -7,7 +7,6 @@ from typing import Optional | |||
| from pydantic import BaseModel | |||
| import contexts | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | |||
| @@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel): | |||
| class ModelProviderFactory: | |||
| provider_position_map: dict[str, int] = {} | |||
| provider_position_map: dict[str, int] | |||
| def __init__(self, tenant_id: str) -> None: | |||
| self.provider_position_map = {} | |||
| self.tenant_id = tenant_id | |||
| self.plugin_model_manager = PluginModelManager() | |||
| @@ -360,11 +361,5 @@ class ModelProviderFactory: | |||
| :param provider: provider name | |||
| :return: plugin id and provider name | |||
| """ | |||
| plugin_id = DEFAULT_PLUGIN_ID | |||
| provider_name = provider | |||
| if "/" in provider: | |||
| # get the plugin_id before provider | |||
| plugin_id = "/".join(provider.split("/")[:-1]) | |||
| provider_name = provider.split("/")[-1] | |||
| return str(plugin_id), provider_name | |||
| provider_id = ModelProviderID(provider) | |||
| return provider_id.plugin_id, provider_id.provider_name | |||
| @@ -13,10 +13,10 @@ from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| from configs import dify_config | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from events.dataset_event import dataset_was_deleted | |||
| @@ -328,14 +328,10 @@ class DatasetService: | |||
| else: | |||
| # add default plugin id to both setting sets, to make sure the plugin model provider is consistent | |||
| plugin_model_provider = dataset.embedding_model_provider | |||
| if "/" not in plugin_model_provider: | |||
| plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}" | |||
| plugin_model_provider = str(ModelProviderID(plugin_model_provider)) | |||
| new_plugin_model_provider = data["embedding_model_provider"] | |||
| if "/" not in new_plugin_model_provider: | |||
| new_plugin_model_provider = ( | |||
| f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}" | |||
| ) | |||
| new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider)) | |||
| if ( | |||
| new_plugin_model_provider != plugin_model_provider | |||