瀏覽代碼

feat: Add caching mechanism for plugin model schemas (#14898)

tags/1.0.1
Yeuoly 8 月之前
父節點
當前提交
4668c4996a
沒有連結到貢獻者的電子郵件帳戶。

+ 9
- 0
api/contexts/__init__.py 查看文件

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
@@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)

plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)

plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)

plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))

plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)

+ 32
- 9
api/core/model_runtime/model_providers/__base/ai_model.py 查看文件

@@ -1,8 +1,11 @@
import decimal
import hashlib
from threading import Lock
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

import contexts
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@@ -139,15 +142,35 @@ class AIModel(BaseModel):
:return: model schema
"""
plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())

with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]

schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)

if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema

return schema

def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""

+ 29
- 10
api/core/model_runtime/model_providers/model_provider_factory.py 查看文件

@@ -1,3 +1,4 @@
import hashlib
import logging
import os
from collections.abc import Sequence
@@ -206,17 +207,35 @@ class ModelProviderFactory:
Get model schema
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
model_schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

return model_schema
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())

with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]

schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)

if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema

return schema

def get_models(
self,

Loading…
取消
儲存