| @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING | |||
| from contexts.wrapper import RecyclableContextVar | |||
| if TYPE_CHECKING: | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||
| from core.tools.plugin_tool.provider import PluginToolProviderController | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo | |||
| plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( | |||
| ContextVar("plugin_tool_providers") | |||
| ) | |||
| plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) | |||
| plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( | |||
| ContextVar("plugin_model_providers") | |||
| ) | |||
| plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | |||
| ContextVar("plugin_model_providers_lock") | |||
| ) | |||
| plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) | |||
| plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( | |||
| ContextVar("plugin_model_schemas") | |||
| ) | |||
| @@ -1,8 +1,11 @@ | |||
| import decimal | |||
| import hashlib | |||
| from threading import Lock | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| import contexts | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | |||
| from core.model_runtime.entities.model_entities import ( | |||
| @@ -139,15 +142,35 @@ class AIModel(BaseModel): | |||
| :return: model schema | |||
| """ | |||
| plugin_model_manager = PluginModelManager() | |||
| return plugin_model_manager.get_model_schema( | |||
| tenant_id=self.tenant_id, | |||
| user_id="unknown", | |||
| plugin_id=self.plugin_id, | |||
| provider=self.provider_name, | |||
| model_type=self.model_type.value, | |||
| model=model, | |||
| credentials=credentials or {}, | |||
| ) | |||
| cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" | |||
| # sort credentials | |||
| sorted_credentials = sorted(credentials.items()) if credentials else [] | |||
| cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) | |||
| try: | |||
| contexts.plugin_model_schemas.get() | |||
| except LookupError: | |||
| contexts.plugin_model_schemas.set({}) | |||
| contexts.plugin_model_schema_lock.set(Lock()) | |||
| with contexts.plugin_model_schema_lock.get(): | |||
| if cache_key in contexts.plugin_model_schemas.get(): | |||
| return contexts.plugin_model_schemas.get()[cache_key] | |||
| schema = plugin_model_manager.get_model_schema( | |||
| tenant_id=self.tenant_id, | |||
| user_id="unknown", | |||
| plugin_id=self.plugin_id, | |||
| provider=self.provider_name, | |||
| model_type=self.model_type.value, | |||
| model=model, | |||
| credentials=credentials or {}, | |||
| ) | |||
| if schema: | |||
| contexts.plugin_model_schemas.get()[cache_key] = schema | |||
| return schema | |||
| def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | |||
| """ | |||
| @@ -1,3 +1,4 @@ | |||
| import hashlib | |||
| import logging | |||
| import os | |||
| from collections.abc import Sequence | |||
| @@ -206,17 +207,35 @@ class ModelProviderFactory: | |||
| Get model schema | |||
| """ | |||
| plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) | |||
| model_schema = self.plugin_model_manager.get_model_schema( | |||
| tenant_id=self.tenant_id, | |||
| user_id="unknown", | |||
| plugin_id=plugin_id, | |||
| provider=provider_name, | |||
| model_type=model_type.value, | |||
| model=model, | |||
| credentials=credentials, | |||
| ) | |||
| cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" | |||
| # sort credentials | |||
| sorted_credentials = sorted(credentials.items()) if credentials else [] | |||
| cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) | |||
| return model_schema | |||
| try: | |||
| contexts.plugin_model_schemas.get() | |||
| except LookupError: | |||
| contexts.plugin_model_schemas.set({}) | |||
| contexts.plugin_model_schema_lock.set(Lock()) | |||
| with contexts.plugin_model_schema_lock.get(): | |||
| if cache_key in contexts.plugin_model_schemas.get(): | |||
| return contexts.plugin_model_schemas.get()[cache_key] | |||
| schema = self.plugin_model_manager.get_model_schema( | |||
| tenant_id=self.tenant_id, | |||
| user_id="unknown", | |||
| plugin_id=plugin_id, | |||
| provider=provider_name, | |||
| model_type=model_type.value, | |||
| model=model, | |||
| credentials=credentials or {}, | |||
| ) | |||
| if schema: | |||
| contexts.plugin_model_schemas.get()[cache_key] = schema | |||
| return schema | |||
| def get_models( | |||
| self, | |||