您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

provider.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import json
  2. from typing import Any, Optional, Self
  3. from core.mcp.types import Tool as RemoteMCPTool
  4. from core.tools.__base.tool_provider import ToolProviderController
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_entities import (
  8. ToolDescription,
  9. ToolEntity,
  10. ToolIdentity,
  11. ToolProviderEntityWithPlugin,
  12. ToolProviderIdentity,
  13. ToolProviderType,
  14. )
  15. from core.tools.mcp_tool.tool import MCPTool
  16. from models.tools import MCPToolProvider
  17. from services.tools.tools_transform_service import ToolTransformService
  class MCPToolProviderController(ToolProviderController):
      """Tool provider controller whose tools are served by a remote MCP server.

      Holds the provider entity together with the connection settings (server
      URL, headers, timeouts) that every ``MCPTool`` it hands out is built with.
      """

      def __init__(
          self,
          entity: ToolProviderEntityWithPlugin,
          provider_id: str,
          tenant_id: str,
          server_url: str,
          headers: Optional[dict[str, str]] = None,
          timeout: Optional[float] = None,
          sse_read_timeout: Optional[float] = None,
      ):
          """
          :param entity: provider entity, including the tool entities it exposes
          :param provider_id: provider identifier (``from_db`` passes the server identifier)
          :param tenant_id: tenant that owns this provider
          :param server_url: URL of the MCP server the tools talk to
          :param headers: extra request headers; ``None`` is stored as an empty dict
          :param timeout: timeout forwarded unchanged to every created ``MCPTool``
          :param sse_read_timeout: SSE read timeout forwarded unchanged to every created ``MCPTool``
          """
          super().__init__(entity)
          # Re-annotated here, presumably to narrow the type the base class declares,
          # so that ``self.entity.tools`` type-checks below.
          self.entity: ToolProviderEntityWithPlugin = entity
          self.tenant_id = tenant_id
          self.provider_id = provider_id
          self.server_url = server_url
          self.headers = headers or {}
          self.timeout = timeout
          self.sse_read_timeout = sse_read_timeout

      @property
      def provider_type(self) -> ToolProviderType:
          """
          returns the type of the provider
          :return: always ``ToolProviderType.MCP``
          """
          return ToolProviderType.MCP

      @classmethod
      def from_db(cls, db_provider: MCPToolProvider) -> Self:
          """
          Build a controller from a persisted ``MCPToolProvider`` record.

          The record's ``tools`` column is parsed as JSON; each element is turned
          into a remote MCP tool definition and then into a ``ToolEntity``.

          :param db_provider: the stored MCP provider record
          :return: a controller configured with the record's URL, timeouts and tools
          """
          tools_data = json.loads(db_provider.tools)
          remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
          user = db_provider.load_user()
          # The same author label is used for the provider and for each of its tools.
          author = user.name if user else "Anonymous"
          tools = [
              ToolEntity(
                  identity=ToolIdentity(
                      author=author,
                      name=remote_mcp_tool.name,
                      label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
                      provider=db_provider.server_identifier,
                      icon=db_provider.icon,
                  ),
                  parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
                  description=ToolDescription(
                      human=I18nObject(
                          en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
                      ),
                      llm=remote_mcp_tool.description or "",
                  ),
                  output_schema=None,
                  # NOTE(review): any non-empty input schema counts as "has runtime
                  # parameters", even one that declares no properties — confirm intended.
                  has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
              )
              for remote_mcp_tool in remote_mcp_tools
          ]
          return cls(
              entity=ToolProviderEntityWithPlugin(
                  identity=ToolProviderIdentity(
                      author=author,
                      name=db_provider.name,
                      label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
                      description=I18nObject(en_US="", zh_Hans=""),
                      icon=db_provider.icon,
                  ),
                  plugin_id=None,
                  credentials_schema=[],
                  tools=tools,
              ),
              provider_id=db_provider.server_identifier or "",
              tenant_id=db_provider.tenant_id or "",
              server_url=db_provider.decrypted_server_url,
              headers={},  # TODO: get headers from db provider
              timeout=db_provider.timeout,
              sse_read_timeout=db_provider.sse_read_timeout,
          )

      def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
          """
          validate the credentials of the provider

          This implementation performs no validation and accepts any input.
          """
          pass

      def _build_mcp_tool(self, tool_entity: ToolEntity) -> MCPTool:
          """Create an ``MCPTool`` for *tool_entity* using this provider's connection settings."""
          return MCPTool(
              entity=tool_entity,
              # A fresh runtime per tool instance, scoped to this provider's tenant.
              runtime=ToolRuntime(tenant_id=self.tenant_id),
              tenant_id=self.tenant_id,
              icon=self.entity.identity.icon,
              server_url=self.server_url,
              provider_id=self.provider_id,
              headers=self.headers,
              timeout=self.timeout,
              sse_read_timeout=self.sse_read_timeout,
          )

      def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
          """
          Return the tool with the given name.

          :param tool_name: name matched against each tool entity's ``identity.name``
          :return: a new ``MCPTool`` bound to this provider's connection settings
          :raises ValueError: if no tool with that name exists
          """
          tool_entity = next(
              (candidate for candidate in self.entity.tools if candidate.identity.name == tool_name), None
          )
          if tool_entity is None:
              raise ValueError(f"Tool with name {tool_name} not found")
          return self._build_mcp_tool(tool_entity)

      def get_tools(self) -> list[MCPTool]:  # type: ignore
          """
          Return a new ``MCPTool`` for every tool entity of this provider, in entity order.
          """
          return [self._build_mcp_tool(tool_entity) for tool_entity in self.entity.tools]