You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mcp_tools_mange_service.py 9.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import hashlib
  2. import json
  3. from datetime import datetime
  4. from typing import Any
  5. from sqlalchemy import or_
  6. from sqlalchemy.exc import IntegrityError
  7. from core.helper import encrypter
  8. from core.mcp.error import MCPAuthError, MCPError
  9. from core.mcp.mcp_client import MCPClient
  10. from core.tools.entities.api_entities import ToolProviderApiEntity
  11. from core.tools.entities.common_entities import I18nObject
  12. from core.tools.entities.tool_entities import ToolProviderType
  13. from core.tools.mcp_tool.provider import MCPToolProviderController
  14. from core.tools.utils.configuration import ProviderConfigEncrypter
  15. from extensions.ext_database import db
  16. from models.tools import MCPToolProvider
  17. from services.tools.tools_transform_service import ToolTransformService
  18. UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
  19. class MCPToolManageService:
  20. """
  21. Service class for managing mcp tools.
  22. """
  23. @staticmethod
  24. def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
  25. res = (
  26. db.session.query(MCPToolProvider)
  27. .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
  28. .first()
  29. )
  30. if not res:
  31. raise ValueError("MCP tool not found")
  32. return res
  33. @staticmethod
  34. def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
  35. res = (
  36. db.session.query(MCPToolProvider)
  37. .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
  38. .first()
  39. )
  40. if not res:
  41. raise ValueError("MCP tool not found")
  42. return res
  43. @staticmethod
  44. def create_mcp_provider(
  45. tenant_id: str,
  46. name: str,
  47. server_url: str,
  48. user_id: str,
  49. icon: str,
  50. icon_type: str,
  51. icon_background: str,
  52. server_identifier: str,
  53. ) -> ToolProviderApiEntity:
  54. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  55. existing_provider = (
  56. db.session.query(MCPToolProvider)
  57. .filter(
  58. MCPToolProvider.tenant_id == tenant_id,
  59. or_(
  60. MCPToolProvider.name == name,
  61. MCPToolProvider.server_url_hash == server_url_hash,
  62. MCPToolProvider.server_identifier == server_identifier,
  63. ),
  64. )
  65. .first()
  66. )
  67. if existing_provider:
  68. if existing_provider.name == name:
  69. raise ValueError(f"MCP tool {name} already exists")
  70. elif existing_provider.server_url_hash == server_url_hash:
  71. raise ValueError(f"MCP tool {server_url} already exists")
  72. elif existing_provider.server_identifier == server_identifier:
  73. raise ValueError(f"MCP tool {server_identifier} already exists")
  74. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  75. mcp_tool = MCPToolProvider(
  76. tenant_id=tenant_id,
  77. name=name,
  78. server_url=encrypted_server_url,
  79. server_url_hash=server_url_hash,
  80. user_id=user_id,
  81. authed=False,
  82. tools="[]",
  83. icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
  84. server_identifier=server_identifier,
  85. )
  86. db.session.add(mcp_tool)
  87. db.session.commit()
  88. return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
  89. @staticmethod
  90. def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
  91. mcp_providers = (
  92. db.session.query(MCPToolProvider)
  93. .filter(MCPToolProvider.tenant_id == tenant_id)
  94. .order_by(MCPToolProvider.name)
  95. .all()
  96. )
  97. return [
  98. ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
  99. for mcp_provider in mcp_providers
  100. ]
  101. @classmethod
  102. def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
  103. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  104. try:
  105. with MCPClient(
  106. mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
  107. ) as mcp_client:
  108. tools = mcp_client.list_tools()
  109. except MCPAuthError as e:
  110. raise ValueError("Please auth the tool first")
  111. except MCPError as e:
  112. raise ValueError(f"Failed to connect to MCP server: {e}")
  113. mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
  114. mcp_provider.authed = True
  115. mcp_provider.updated_at = datetime.now()
  116. db.session.commit()
  117. user = mcp_provider.load_user()
  118. return ToolProviderApiEntity(
  119. id=mcp_provider.id,
  120. name=mcp_provider.name,
  121. tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
  122. type=ToolProviderType.MCP,
  123. icon=mcp_provider.icon,
  124. author=user.name if user else "Anonymous",
  125. server_url=mcp_provider.masked_server_url,
  126. updated_at=int(mcp_provider.updated_at.timestamp()),
  127. description=I18nObject(en_US="", zh_Hans=""),
  128. label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
  129. plugin_unique_identifier=mcp_provider.server_identifier,
  130. )
  131. @classmethod
  132. def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
  133. mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  134. db.session.delete(mcp_tool)
  135. db.session.commit()
  136. @classmethod
  137. def update_mcp_provider(
  138. cls,
  139. tenant_id: str,
  140. provider_id: str,
  141. name: str,
  142. server_url: str,
  143. icon: str,
  144. icon_type: str,
  145. icon_background: str,
  146. server_identifier: str,
  147. ):
  148. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  149. mcp_provider.updated_at = datetime.now()
  150. mcp_provider.name = name
  151. mcp_provider.icon = (
  152. json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
  153. )
  154. mcp_provider.server_identifier = server_identifier
  155. if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
  156. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  157. mcp_provider.server_url = encrypted_server_url
  158. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  159. if server_url_hash != mcp_provider.server_url_hash:
  160. cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
  161. mcp_provider.server_url_hash = server_url_hash
  162. try:
  163. db.session.commit()
  164. except IntegrityError as e:
  165. db.session.rollback()
  166. error_msg = str(e.orig)
  167. if "unique_mcp_provider_name" in error_msg:
  168. raise ValueError(f"MCP tool {name} already exists")
  169. elif "unique_mcp_provider_server_url" in error_msg:
  170. raise ValueError(f"MCP tool {server_url} already exists")
  171. elif "unique_mcp_provider_server_identifier" in error_msg:
  172. raise ValueError(f"MCP tool {server_identifier} already exists")
  173. else:
  174. raise
  175. @classmethod
  176. def update_mcp_provider_credentials(
  177. cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
  178. ):
  179. provider_controller = MCPToolProviderController._from_db(mcp_provider)
  180. tool_configuration = ProviderConfigEncrypter(
  181. tenant_id=mcp_provider.tenant_id,
  182. config=list(provider_controller.get_credentials_schema()),
  183. provider_type=provider_controller.provider_type.value,
  184. provider_identity=provider_controller.provider_id,
  185. )
  186. credentials = tool_configuration.encrypt(credentials)
  187. mcp_provider.updated_at = datetime.now()
  188. mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
  189. mcp_provider.authed = authed
  190. if not authed:
  191. mcp_provider.tools = "[]"
  192. db.session.commit()
  193. @classmethod
  194. def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
  195. """re-connect mcp provider"""
  196. try:
  197. with MCPClient(
  198. mcp_provider.decrypted_server_url,
  199. provider_id,
  200. tenant_id,
  201. authed=False,
  202. for_list=True,
  203. ) as mcp_client:
  204. tools = mcp_client.list_tools()
  205. mcp_provider.authed = True
  206. mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
  207. except MCPAuthError:
  208. mcp_provider.authed = False
  209. mcp_provider.tools = "[]"
  210. except MCPError as e:
  211. raise ValueError(f"Failed to re-connect MCP server: {e}") from e
  212. # reset credentials
  213. mcp_provider.encrypted_credentials = "{}"