瀏覽代碼

Fix/plugin race condition (#14253)

tags/1.0.0
Yeuoly 8 月之前
父節點
當前提交
490b6d092e
沒有連結到貢獻者的電子郵件帳戶。

+ 7
- 0
api/app_factory.py 查看文件

import time import time


from configs import dify_config from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp from dify_app import DifyApp




dify_app = DifyApp(__name__) dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump()) dify_app.config.from_mapping(dify_config.model_dump())


# add before request hook
@dify_app.before_request
def before_request():
# add a unique identifier to each request
RecyclableContextVar.increment_thread_recycles()

return dify_app return dify_app





+ 15
- 4
api/contexts/__init__.py 查看文件

from threading import Lock from threading import Lock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING


from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING: if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController


workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")


plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
"""
To avoid race conditions caused by gunicorn thread recycling, the plain ContextVars below are replaced with RecyclableContextVar wrappers.
"""
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))


plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)

+ 65
- 0
api/contexts/wrapper.py 查看文件

from contextvars import ContextVar
from typing import Generic, TypeVar

T = TypeVar("T")


class HiddenValue:
pass


_default = HiddenValue()


class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

NOTE: you need to call `increment_thread_recycles` before requests
"""

_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")

@classmethod
def increment_thread_recycles(cls):
try:
recycles = cls._thread_recycles.get()
cls._thread_recycles.set(recycles + 1)
except LookupError:
cls._thread_recycles.set(0)

def __init__(self, context_var: ContextVar[T]):
self._context_var = context_var
self._updates = ContextVar[int](context_var.name + "_updates", default=0)

def get(self, default: T | HiddenValue = _default) -> T:
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)

# check if thread is recycled and should be updated
if thread_recycles < self_updates:
return self._context_var.get()
else:
# thread_recycles >= self_updates, means current context is invalid
if isinstance(default, HiddenValue) or default is _default:
raise LookupError
else:
return default

def set(self, value: T):
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
# increase it manually
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)

if self._updates.get() == self._thread_recycles.get(0):
# after increment,
self._updates.set(self._updates.get() + 1)

# set the context
self._context_var.set(value)

+ 2
- 2
api/core/agent/entities.py 查看文件

from enum import StrEnum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union


from pydantic import BaseModel
from pydantic import BaseModel, Field


from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType


provider_type: ToolProviderType provider_type: ToolProviderType
provider_id: str provider_id: str
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = {}
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None plugin_unique_identifier: str | None = None





+ 2
- 4
api/core/app/app_config/easy_ui_based_app/model_config/manager.py 查看文件

from typing import Any from typing import Any


from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager




raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")


if "/" not in config["model"]["provider"]: if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))


if config["model"]["provider"] not in model_provider_names: if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

+ 4
- 4
api/core/app/app_config/entities.py 查看文件

provider: str provider: str
model: str model: str
mode: Optional[str] = None mode: Optional[str] = None
parameters: dict[str, Any] = {}
stop: list[str] = []
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)




class AdvancedChatMessageEntity(BaseModel): class AdvancedChatMessageEntity(BaseModel):


variable: str variable: str
type: str type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)




class DatasetRetrieveConfigEntity(BaseModel): class DatasetRetrieveConfigEntity(BaseModel):
""" """


type: str type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)




class TextToSpeechEntity(BaseModel): class TextToSpeechEntity(BaseModel):

+ 4
- 4
api/core/app/entities/app_invoke_entities.py 查看文件

model_schema: AIModelEntity model_schema: AIModelEntity
mode: str mode: str
provider_model_bundle: ProviderModelBundle provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
credentials: dict[str, Any] = Field(default_factory=dict)
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)


# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
call_depth: int = 0 call_depth: int = 0


# extra parameters, like: auto_generate_conversation_name # extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
extras: dict[str, Any] = Field(default_factory=dict)


# tracing instance # tracing instance
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None

+ 4
- 5
api/core/entities/provider_configuration.py 查看文件

from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional from typing import Optional


from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_ from sqlalchemy import or_


from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
""" """


tenant_id: str tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)


def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id) super().__init__(tenant_id=tenant_id)


def __getitem__(self, key): def __getitem__(self, key):
if "/" not in key: if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))


return self.configurations[key] return self.configurations[key]




def get(self, key, default=None) -> ProviderConfiguration | None: def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key: if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))


return self.configurations.get(key, default) # type: ignore return self.configurations.get(key, default) # type: ignore



+ 5
- 1
api/core/hosting_configuration.py 查看文件





class HostingConfiguration: class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None moderation_config: Optional[HostedModerationConfig] = None


def __init__(self) -> None:
self.provider_map = {}
self.moderation_config = None

def init_app(self, app: Flask) -> None: def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return

+ 5
- 10
api/core/model_runtime/model_providers/model_provider_factory.py 查看文件

from pydantic import BaseModel from pydantic import BaseModel


import contexts import contexts
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity




class ModelProviderFactory: class ModelProviderFactory:
provider_position_map: dict[str, int] = {}
provider_position_map: dict[str, int]


def __init__(self, tenant_id: str) -> None: def __init__(self, tenant_id: str) -> None:
self.provider_position_map = {}

self.tenant_id = tenant_id self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelManager() self.plugin_model_manager = PluginModelManager()


:param provider: provider name :param provider: provider name
:return: plugin id and provider name :return: plugin id and provider name
""" """
plugin_id = DEFAULT_PLUGIN_ID
provider_name = provider
if "/" in provider:
# get the plugin_id before provider
plugin_id = "/".join(provider.split("/")[:-1])
provider_name = provider.split("/")[-1]

return str(plugin_id), provider_name
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

+ 3
- 7
api/services/dataset_service.py 查看文件

from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound


from configs import dify_config from configs import dify_config
from core.entities import DEFAULT_PLUGIN_ID
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
else: else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent # add default plugin id to both setting sets, to make sure the plugin model provider is consistent
plugin_model_provider = dataset.embedding_model_provider plugin_model_provider = dataset.embedding_model_provider
if "/" not in plugin_model_provider:
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
plugin_model_provider = str(ModelProviderID(plugin_model_provider))


new_plugin_model_provider = data["embedding_model_provider"] new_plugin_model_provider = data["embedding_model_provider"]
if "/" not in new_plugin_model_provider:
new_plugin_model_provider = (
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
)
new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))


if ( if (
new_plugin_model_provider != plugin_model_provider new_plugin_model_provider != plugin_model_provider

Loading…
取消
儲存