You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

provider.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import json
  2. from typing import Any, Self
  3. from core.mcp.types import Tool as RemoteMCPTool
  4. from core.tools.__base.tool_provider import ToolProviderController
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_entities import (
  8. ToolDescription,
  9. ToolEntity,
  10. ToolIdentity,
  11. ToolProviderEntityWithPlugin,
  12. ToolProviderIdentity,
  13. ToolProviderType,
  14. )
  15. from core.tools.mcp_tool.tool import MCPTool
  16. from models.tools import MCPToolProvider
  17. from services.tools.tools_transform_service import ToolTransformService
  18. class MCPToolProviderController(ToolProviderController):
  19. def __init__(
  20. self,
  21. entity: ToolProviderEntityWithPlugin,
  22. provider_id: str,
  23. tenant_id: str,
  24. server_url: str,
  25. headers: dict[str, str] | None = None,
  26. timeout: float | None = None,
  27. sse_read_timeout: float | None = None,
  28. ):
  29. super().__init__(entity)
  30. self.entity: ToolProviderEntityWithPlugin = entity
  31. self.tenant_id = tenant_id
  32. self.provider_id = provider_id
  33. self.server_url = server_url
  34. self.headers = headers or {}
  35. self.timeout = timeout
  36. self.sse_read_timeout = sse_read_timeout
  37. @property
  38. def provider_type(self) -> ToolProviderType:
  39. """
  40. returns the type of the provider
  41. :return: type of the provider
  42. """
  43. return ToolProviderType.MCP
  44. @classmethod
  45. def from_db(cls, db_provider: MCPToolProvider) -> Self:
  46. """
  47. from db provider
  48. """
  49. tools = []
  50. tools_data = json.loads(db_provider.tools)
  51. remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
  52. user = db_provider.load_user()
  53. tools = [
  54. ToolEntity(
  55. identity=ToolIdentity(
  56. author=user.name if user else "Anonymous",
  57. name=remote_mcp_tool.name,
  58. label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
  59. provider=db_provider.server_identifier,
  60. icon=db_provider.icon,
  61. ),
  62. parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
  63. description=ToolDescription(
  64. human=I18nObject(
  65. en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
  66. ),
  67. llm=remote_mcp_tool.description or "",
  68. ),
  69. has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
  70. )
  71. for remote_mcp_tool in remote_mcp_tools
  72. ]
  73. return cls(
  74. entity=ToolProviderEntityWithPlugin(
  75. identity=ToolProviderIdentity(
  76. author=user.name if user else "Anonymous",
  77. name=db_provider.name,
  78. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  79. description=I18nObject(en_US="", zh_Hans=""),
  80. icon=db_provider.icon,
  81. ),
  82. plugin_id=None,
  83. credentials_schema=[],
  84. tools=tools,
  85. ),
  86. provider_id=db_provider.server_identifier or "",
  87. tenant_id=db_provider.tenant_id or "",
  88. server_url=db_provider.decrypted_server_url,
  89. headers=db_provider.decrypted_headers or {},
  90. timeout=db_provider.timeout,
  91. sse_read_timeout=db_provider.sse_read_timeout,
  92. )
  93. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
  94. """
  95. validate the credentials of the provider
  96. """
  97. pass
  98. def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
  99. """
  100. return tool with given name
  101. """
  102. tool_entity = next(
  103. (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
  104. )
  105. if not tool_entity:
  106. raise ValueError(f"Tool with name {tool_name} not found")
  107. return MCPTool(
  108. entity=tool_entity,
  109. runtime=ToolRuntime(tenant_id=self.tenant_id),
  110. tenant_id=self.tenant_id,
  111. icon=self.entity.identity.icon,
  112. server_url=self.server_url,
  113. provider_id=self.provider_id,
  114. headers=self.headers,
  115. timeout=self.timeout,
  116. sse_read_timeout=self.sse_read_timeout,
  117. )
  118. def get_tools(self) -> list[MCPTool]: # type: ignore
  119. """
  120. get all tools
  121. """
  122. return [
  123. MCPTool(
  124. entity=tool_entity,
  125. runtime=ToolRuntime(tenant_id=self.tenant_id),
  126. tenant_id=self.tenant_id,
  127. icon=self.entity.identity.icon,
  128. server_url=self.server_url,
  129. provider_id=self.provider_id,
  130. headers=self.headers,
  131. timeout=self.timeout,
  132. sse_read_timeout=self.sse_read_timeout,
  133. )
  134. for tool_entity in self.entity.tools
  135. ]