您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

mcp_tools_manage_service.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import hashlib
  2. import json
  3. from datetime import datetime
  4. from typing import Any
  5. from sqlalchemy import or_
  6. from sqlalchemy.exc import IntegrityError
  7. from core.helper import encrypter
  8. from core.helper.provider_cache import NoOpProviderCredentialCache
  9. from core.mcp.error import MCPAuthError, MCPError
  10. from core.mcp.mcp_client import MCPClient
  11. from core.tools.entities.api_entities import ToolProviderApiEntity
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_entities import ToolProviderType
  14. from core.tools.mcp_tool.provider import MCPToolProviderController
  15. from core.tools.utils.encryption import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import MCPToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
  19. UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
  20. class MCPToolManageService:
  21. """
  22. Service class for managing mcp tools.
  23. """
  24. @staticmethod
  25. def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
  26. res = (
  27. db.session.query(MCPToolProvider)
  28. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
  29. .first()
  30. )
  31. if not res:
  32. raise ValueError("MCP tool not found")
  33. return res
  34. @staticmethod
  35. def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
  36. res = (
  37. db.session.query(MCPToolProvider)
  38. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
  39. .first()
  40. )
  41. if not res:
  42. raise ValueError("MCP tool not found")
  43. return res
  44. @staticmethod
  45. def create_mcp_provider(
  46. tenant_id: str,
  47. name: str,
  48. server_url: str,
  49. user_id: str,
  50. icon: str,
  51. icon_type: str,
  52. icon_background: str,
  53. server_identifier: str,
  54. timeout: float,
  55. sse_read_timeout: float,
  56. ) -> ToolProviderApiEntity:
  57. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  58. existing_provider = (
  59. db.session.query(MCPToolProvider)
  60. .where(
  61. MCPToolProvider.tenant_id == tenant_id,
  62. or_(
  63. MCPToolProvider.name == name,
  64. MCPToolProvider.server_url_hash == server_url_hash,
  65. MCPToolProvider.server_identifier == server_identifier,
  66. ),
  67. )
  68. .first()
  69. )
  70. if existing_provider:
  71. if existing_provider.name == name:
  72. raise ValueError(f"MCP tool {name} already exists")
  73. if existing_provider.server_url_hash == server_url_hash:
  74. raise ValueError(f"MCP tool {server_url} already exists")
  75. if existing_provider.server_identifier == server_identifier:
  76. raise ValueError(f"MCP tool {server_identifier} already exists")
  77. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  78. mcp_tool = MCPToolProvider(
  79. tenant_id=tenant_id,
  80. name=name,
  81. server_url=encrypted_server_url,
  82. server_url_hash=server_url_hash,
  83. user_id=user_id,
  84. authed=False,
  85. tools="[]",
  86. icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
  87. server_identifier=server_identifier,
  88. timeout=timeout,
  89. sse_read_timeout=sse_read_timeout,
  90. )
  91. db.session.add(mcp_tool)
  92. db.session.commit()
  93. return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
  94. @staticmethod
  95. def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
  96. mcp_providers = (
  97. db.session.query(MCPToolProvider)
  98. .where(MCPToolProvider.tenant_id == tenant_id)
  99. .order_by(MCPToolProvider.name)
  100. .all()
  101. )
  102. return [
  103. ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
  104. for mcp_provider in mcp_providers
  105. ]
  106. @classmethod
  107. def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
  108. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  109. server_url = mcp_provider.decrypted_server_url
  110. authed = mcp_provider.authed
  111. try:
  112. with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
  113. tools = mcp_client.list_tools()
  114. except MCPAuthError:
  115. raise ValueError("Please auth the tool first")
  116. except MCPError as e:
  117. raise ValueError(f"Failed to connect to MCP server: {e}")
  118. try:
  119. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  120. mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
  121. mcp_provider.authed = True
  122. mcp_provider.updated_at = datetime.now()
  123. db.session.commit()
  124. except Exception:
  125. db.session.rollback()
  126. raise
  127. user = mcp_provider.load_user()
  128. return ToolProviderApiEntity(
  129. id=mcp_provider.id,
  130. name=mcp_provider.name,
  131. tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
  132. type=ToolProviderType.MCP,
  133. icon=mcp_provider.icon,
  134. author=user.name if user else "Anonymous",
  135. server_url=mcp_provider.masked_server_url,
  136. updated_at=int(mcp_provider.updated_at.timestamp()),
  137. description=I18nObject(en_US="", zh_Hans=""),
  138. label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
  139. plugin_unique_identifier=mcp_provider.server_identifier,
  140. )
  141. @classmethod
  142. def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
  143. mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  144. db.session.delete(mcp_tool)
  145. db.session.commit()
  146. @classmethod
  147. def update_mcp_provider(
  148. cls,
  149. tenant_id: str,
  150. provider_id: str,
  151. name: str,
  152. server_url: str,
  153. icon: str,
  154. icon_type: str,
  155. icon_background: str,
  156. server_identifier: str,
  157. timeout: float | None = None,
  158. sse_read_timeout: float | None = None,
  159. ):
  160. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  161. reconnect_result = None
  162. encrypted_server_url = None
  163. server_url_hash = None
  164. if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
  165. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  166. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  167. if server_url_hash != mcp_provider.server_url_hash:
  168. reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
  169. try:
  170. mcp_provider.updated_at = datetime.now()
  171. mcp_provider.name = name
  172. mcp_provider.icon = (
  173. json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
  174. )
  175. mcp_provider.server_identifier = server_identifier
  176. if encrypted_server_url is not None and server_url_hash is not None:
  177. mcp_provider.server_url = encrypted_server_url
  178. mcp_provider.server_url_hash = server_url_hash
  179. if reconnect_result:
  180. mcp_provider.authed = reconnect_result["authed"]
  181. mcp_provider.tools = reconnect_result["tools"]
  182. mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
  183. if timeout is not None:
  184. mcp_provider.timeout = timeout
  185. if sse_read_timeout is not None:
  186. mcp_provider.sse_read_timeout = sse_read_timeout
  187. db.session.commit()
  188. except IntegrityError as e:
  189. db.session.rollback()
  190. error_msg = str(e.orig)
  191. if "unique_mcp_provider_name" in error_msg:
  192. raise ValueError(f"MCP tool {name} already exists")
  193. if "unique_mcp_provider_server_url" in error_msg:
  194. raise ValueError(f"MCP tool {server_url} already exists")
  195. if "unique_mcp_provider_server_identifier" in error_msg:
  196. raise ValueError(f"MCP tool {server_identifier} already exists")
  197. raise
  198. except Exception:
  199. db.session.rollback()
  200. raise
  201. @classmethod
  202. def update_mcp_provider_credentials(
  203. cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
  204. ):
  205. provider_controller = MCPToolProviderController.from_db(mcp_provider)
  206. tool_configuration = ProviderConfigEncrypter(
  207. tenant_id=mcp_provider.tenant_id,
  208. config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
  209. provider_config_cache=NoOpProviderCredentialCache(),
  210. )
  211. credentials = tool_configuration.encrypt(credentials)
  212. mcp_provider.updated_at = datetime.now()
  213. mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
  214. mcp_provider.authed = authed
  215. if not authed:
  216. mcp_provider.tools = "[]"
  217. db.session.commit()
  218. @classmethod
  219. def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
  220. try:
  221. with MCPClient(
  222. server_url,
  223. provider_id,
  224. tenant_id,
  225. authed=False,
  226. for_list=True,
  227. ) as mcp_client:
  228. tools = mcp_client.list_tools()
  229. return {
  230. "authed": True,
  231. "tools": json.dumps([tool.model_dump() for tool in tools]),
  232. "encrypted_credentials": "{}",
  233. }
  234. except MCPAuthError:
  235. return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
  236. except MCPError as e:
  237. raise ValueError(f"Failed to re-connect MCP server: {e}") from e