浏览代码

refactor: Simplify plugin and provider ID generation logic and deduplicate plugin_ids (#14041)

tags/1.0.0
Yeuoly 8 个月前
父节点
当前提交
23888398d1
没有帐户链接到提交者的电子邮件
共有 2 个文件被更改,包括 8 次插入43 次删除
  1. 3
    10
      api/services/plugin/dependencies_analysis.py
  2. 5
    33
      api/services/plugin/plugin_migration.py

+ 3
- 10
api/services/plugin/dependencies_analysis.py 查看文件

from core.helper import marketplace from core.helper import marketplace
from core.plugin.entities.plugin import GenericProviderID, PluginDependency, PluginInstallationSource
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager from core.plugin.manager.plugin import PluginInstallationManager




Convert the tool id to the plugin_id Convert the tool id to the plugin_id
""" """
try: try:
tool_provider_id = GenericProviderID(tool_id)
if tool_id in ["jina", "siliconflow"]:
tool_provider_id.plugin_name = tool_provider_id.plugin_name + "_tool"
return tool_provider_id.plugin_id
return ToolProviderID(tool_id).plugin_id
except Exception as e: except Exception as e:
raise e raise e


Convert the model provider id to the plugin_id Convert the model provider id to the plugin_id
""" """
try: try:
generic_provider_id = GenericProviderID(model_provider_id)
if model_provider_id == "google":
generic_provider_id.plugin_name = "gemini"

return generic_provider_id.plugin_id
return ModelProviderID(model_provider_id).plugin_id
except Exception as e: except Exception as e:
raise e raise e



+ 5
- 33
api/services/plugin/plugin_migration.py 查看文件

from sqlalchemy.orm import Session from sqlalchemy.orm import Session


from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.helper import marketplace from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
result = [] result = []
for row in rs: for row in rs:
provider_name = str(row[0]) provider_name = str(row[0])
if provider_name and "/" not in provider_name:
if provider_name == "google":
provider_name = "gemini"

result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
elif provider_name:
result.append(provider_name)
result.append(ModelProviderID(provider_name).plugin_id)


return result return result


rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
result = [] result = []
for row in rs: for row in rs:
if "/" not in row.provider:
result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
else:
result.append(row.provider)
result.append(ToolProviderID(row.provider).plugin_id)


return result return result


@classmethod
def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
"""
Handle builtin tool provider.
"""
if provider_name == "jina":
provider_name = "jina_tool"
elif provider_name == "siliconflow":
provider_name = "siliconflow_tool"
elif provider_name == "stepfun":
provider_name = "stepfun_tool"

if "/" not in provider_name:
return DEFAULT_PLUGIN_ID + "/" + provider_name
else:
return provider_name

@classmethod @classmethod
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]: def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
""" """
provider_name = data.get("provider_name") provider_name = data.get("provider_name")
provider_type = data.get("provider_type") provider_type = data.get("provider_type")
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value: if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
provider_name = cls._handle_builtin_tool_provider(provider_name)
result.append(provider_name)
result.append(ToolProviderID(provider_name).plugin_id)


return result return result


tool_entity.provider_type == ToolProviderType.BUILT_IN.value tool_entity.provider_type == ToolProviderType.BUILT_IN.value
and tool_entity.provider_id not in excluded_providers and tool_entity.provider_id not in excluded_providers
): ):
result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)


except Exception: except Exception:
logger.exception(f"Failed to process tool {tool}") logger.exception(f"Failed to process tool {tool}")

正在加载...
取消
保存