| from contexts.wrapper import RecyclableContextVar | from contexts.wrapper import RecyclableContextVar | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | ||||
| from core.tools.plugin_tool.provider import PluginToolProviderController | from core.tools.plugin_tool.provider import PluginToolProviderController | ||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( | plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( | ||||
| ContextVar("plugin_tool_providers") | ContextVar("plugin_tool_providers") | ||||
| ) | ) | ||||
| plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) | plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) | ||||
| plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( | plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( | ||||
| ContextVar("plugin_model_providers") | ContextVar("plugin_model_providers") | ||||
| ) | ) | ||||
| plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | ||||
| ContextVar("plugin_model_providers_lock") | ContextVar("plugin_model_providers_lock") | ||||
| ) | ) | ||||
| plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) | |||||
| plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( | |||||
| ContextVar("plugin_model_schemas") | |||||
| ) |
| import decimal | import decimal | ||||
| import hashlib | |||||
| from threading import Lock | |||||
| from typing import Optional | from typing import Optional | ||||
| from pydantic import BaseModel, ConfigDict, Field | from pydantic import BaseModel, ConfigDict, Field | ||||
| import contexts | |||||
| from core.model_runtime.entities.common_entities import I18nObject | from core.model_runtime.entities.common_entities import I18nObject | ||||
| from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | ||||
| from core.model_runtime.entities.model_entities import ( | from core.model_runtime.entities.model_entities import ( | ||||
| :return: model schema | :return: model schema | ||||
| """ | """ | ||||
| plugin_model_manager = PluginModelManager() | plugin_model_manager = PluginModelManager() | ||||
| return plugin_model_manager.get_model_schema( | |||||
| tenant_id=self.tenant_id, | |||||
| user_id="unknown", | |||||
| plugin_id=self.plugin_id, | |||||
| provider=self.provider_name, | |||||
| model_type=self.model_type.value, | |||||
| model=model, | |||||
| credentials=credentials or {}, | |||||
| ) | |||||
| cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" | |||||
| # sort credentials | |||||
| sorted_credentials = sorted(credentials.items()) if credentials else [] | |||||
| cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) | |||||
| try: | |||||
| contexts.plugin_model_schemas.get() | |||||
| except LookupError: | |||||
| contexts.plugin_model_schemas.set({}) | |||||
| contexts.plugin_model_schema_lock.set(Lock()) | |||||
| with contexts.plugin_model_schema_lock.get(): | |||||
| if cache_key in contexts.plugin_model_schemas.get(): | |||||
| return contexts.plugin_model_schemas.get()[cache_key] | |||||
| schema = plugin_model_manager.get_model_schema( | |||||
| tenant_id=self.tenant_id, | |||||
| user_id="unknown", | |||||
| plugin_id=self.plugin_id, | |||||
| provider=self.provider_name, | |||||
| model_type=self.model_type.value, | |||||
| model=model, | |||||
| credentials=credentials or {}, | |||||
| ) | |||||
| if schema: | |||||
| contexts.plugin_model_schemas.get()[cache_key] = schema | |||||
| return schema | |||||
| def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | ||||
| """ | """ |
| import hashlib | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| Get model schema | Get model schema | ||||
| """ | """ | ||||
| plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) | plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) | ||||
| model_schema = self.plugin_model_manager.get_model_schema( | |||||
| tenant_id=self.tenant_id, | |||||
| user_id="unknown", | |||||
| plugin_id=plugin_id, | |||||
| provider=provider_name, | |||||
| model_type=model_type.value, | |||||
| model=model, | |||||
| credentials=credentials, | |||||
| ) | |||||
| cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" | |||||
| # sort credentials | |||||
| sorted_credentials = sorted(credentials.items()) if credentials else [] | |||||
| cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) | |||||
| return model_schema | |||||
| try: | |||||
| contexts.plugin_model_schemas.get() | |||||
| except LookupError: | |||||
| contexts.plugin_model_schemas.set({}) | |||||
| contexts.plugin_model_schema_lock.set(Lock()) | |||||
| with contexts.plugin_model_schema_lock.get(): | |||||
| if cache_key in contexts.plugin_model_schemas.get(): | |||||
| return contexts.plugin_model_schemas.get()[cache_key] | |||||
| schema = self.plugin_model_manager.get_model_schema( | |||||
| tenant_id=self.tenant_id, | |||||
| user_id="unknown", | |||||
| plugin_id=plugin_id, | |||||
| provider=provider_name, | |||||
| model_type=model_type.value, | |||||
| model=model, | |||||
| credentials=credentials or {}, | |||||
| ) | |||||
| if schema: | |||||
| contexts.plugin_model_schemas.get()[cache_key] = schema | |||||
| return schema | |||||
| def get_models( | def get_models( | ||||
| self, | self, |