You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mcp_tools_manage_service.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import hashlib
  2. import json
  3. from datetime import datetime
  4. from typing import Any, cast
  5. from sqlalchemy import or_
  6. from sqlalchemy.exc import IntegrityError
  7. from core.helper import encrypter
  8. from core.helper.provider_cache import NoOpProviderCredentialCache
  9. from core.mcp.error import MCPAuthError, MCPError
  10. from core.mcp.mcp_client import MCPClient
  11. from core.tools.entities.api_entities import ToolProviderApiEntity
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_entities import ToolProviderType
  14. from core.tools.mcp_tool.provider import MCPToolProviderController
  15. from core.tools.utils.encryption import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import MCPToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
  19. UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
  20. class MCPToolManageService:
  21. """
  22. Service class for managing mcp tools.
  23. """
  24. @staticmethod
  25. def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
  26. """
  27. Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
  28. Args:
  29. headers: Dictionary of headers to encrypt
  30. tenant_id: Tenant ID for encryption
  31. Returns:
  32. Dictionary with all headers encrypted
  33. """
  34. if not headers:
  35. return {}
  36. from core.entities.provider_entities import BasicProviderConfig
  37. from core.helper.provider_cache import NoOpProviderCredentialCache
  38. from core.tools.utils.encryption import create_provider_encrypter
  39. # Create dynamic config for all headers as SECRET_INPUT
  40. config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
  41. encrypter_instance, _ = create_provider_encrypter(
  42. tenant_id=tenant_id,
  43. config=config,
  44. cache=NoOpProviderCredentialCache(),
  45. )
  46. return cast(dict[str, str], encrypter_instance.encrypt(headers))
  47. @staticmethod
  48. def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
  49. res = (
  50. db.session.query(MCPToolProvider)
  51. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
  52. .first()
  53. )
  54. if not res:
  55. raise ValueError("MCP tool not found")
  56. return res
  57. @staticmethod
  58. def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
  59. res = (
  60. db.session.query(MCPToolProvider)
  61. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
  62. .first()
  63. )
  64. if not res:
  65. raise ValueError("MCP tool not found")
  66. return res
  67. @staticmethod
  68. def create_mcp_provider(
  69. tenant_id: str,
  70. name: str,
  71. server_url: str,
  72. user_id: str,
  73. icon: str,
  74. icon_type: str,
  75. icon_background: str,
  76. server_identifier: str,
  77. timeout: float,
  78. sse_read_timeout: float,
  79. headers: dict[str, str] | None = None,
  80. ) -> ToolProviderApiEntity:
  81. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  82. existing_provider = (
  83. db.session.query(MCPToolProvider)
  84. .where(
  85. MCPToolProvider.tenant_id == tenant_id,
  86. or_(
  87. MCPToolProvider.name == name,
  88. MCPToolProvider.server_url_hash == server_url_hash,
  89. MCPToolProvider.server_identifier == server_identifier,
  90. ),
  91. )
  92. .first()
  93. )
  94. if existing_provider:
  95. if existing_provider.name == name:
  96. raise ValueError(f"MCP tool {name} already exists")
  97. if existing_provider.server_url_hash == server_url_hash:
  98. raise ValueError(f"MCP tool {server_url} already exists")
  99. if existing_provider.server_identifier == server_identifier:
  100. raise ValueError(f"MCP tool {server_identifier} already exists")
  101. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  102. # Encrypt headers
  103. encrypted_headers = None
  104. if headers:
  105. encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
  106. encrypted_headers = json.dumps(encrypted_headers_dict)
  107. mcp_tool = MCPToolProvider(
  108. tenant_id=tenant_id,
  109. name=name,
  110. server_url=encrypted_server_url,
  111. server_url_hash=server_url_hash,
  112. user_id=user_id,
  113. authed=False,
  114. tools="[]",
  115. icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
  116. server_identifier=server_identifier,
  117. timeout=timeout,
  118. sse_read_timeout=sse_read_timeout,
  119. encrypted_headers=encrypted_headers,
  120. )
  121. db.session.add(mcp_tool)
  122. db.session.commit()
  123. return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
  124. @staticmethod
  125. def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
  126. mcp_providers = (
  127. db.session.query(MCPToolProvider)
  128. .where(MCPToolProvider.tenant_id == tenant_id)
  129. .order_by(MCPToolProvider.name)
  130. .all()
  131. )
  132. return [
  133. ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
  134. for mcp_provider in mcp_providers
  135. ]
  136. @classmethod
  137. def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
  138. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  139. server_url = mcp_provider.decrypted_server_url
  140. authed = mcp_provider.authed
  141. headers = mcp_provider.decrypted_headers
  142. timeout = mcp_provider.timeout
  143. sse_read_timeout = mcp_provider.sse_read_timeout
  144. try:
  145. with MCPClient(
  146. server_url,
  147. provider_id,
  148. tenant_id,
  149. authed=authed,
  150. for_list=True,
  151. headers=headers,
  152. timeout=timeout,
  153. sse_read_timeout=sse_read_timeout,
  154. ) as mcp_client:
  155. tools = mcp_client.list_tools()
  156. except MCPAuthError:
  157. raise ValueError("Please auth the tool first")
  158. except MCPError as e:
  159. raise ValueError(f"Failed to connect to MCP server: {e}")
  160. try:
  161. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  162. mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
  163. mcp_provider.authed = True
  164. mcp_provider.updated_at = datetime.now()
  165. db.session.commit()
  166. except Exception:
  167. db.session.rollback()
  168. raise
  169. user = mcp_provider.load_user()
  170. return ToolProviderApiEntity(
  171. id=mcp_provider.id,
  172. name=mcp_provider.name,
  173. tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
  174. type=ToolProviderType.MCP,
  175. icon=mcp_provider.icon,
  176. author=user.name if user else "Anonymous",
  177. server_url=mcp_provider.masked_server_url,
  178. updated_at=int(mcp_provider.updated_at.timestamp()),
  179. description=I18nObject(en_US="", zh_Hans=""),
  180. label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
  181. plugin_unique_identifier=mcp_provider.server_identifier,
  182. )
  183. @classmethod
  184. def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
  185. mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  186. db.session.delete(mcp_tool)
  187. db.session.commit()
  188. @classmethod
  189. def update_mcp_provider(
  190. cls,
  191. tenant_id: str,
  192. provider_id: str,
  193. name: str,
  194. server_url: str,
  195. icon: str,
  196. icon_type: str,
  197. icon_background: str,
  198. server_identifier: str,
  199. timeout: float | None = None,
  200. sse_read_timeout: float | None = None,
  201. headers: dict[str, str] | None = None,
  202. ):
  203. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  204. reconnect_result = None
  205. encrypted_server_url = None
  206. server_url_hash = None
  207. if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
  208. encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
  209. server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
  210. if server_url_hash != mcp_provider.server_url_hash:
  211. reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
  212. try:
  213. mcp_provider.updated_at = datetime.now()
  214. mcp_provider.name = name
  215. mcp_provider.icon = (
  216. json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
  217. )
  218. mcp_provider.server_identifier = server_identifier
  219. if encrypted_server_url is not None and server_url_hash is not None:
  220. mcp_provider.server_url = encrypted_server_url
  221. mcp_provider.server_url_hash = server_url_hash
  222. if reconnect_result:
  223. mcp_provider.authed = reconnect_result["authed"]
  224. mcp_provider.tools = reconnect_result["tools"]
  225. mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
  226. if timeout is not None:
  227. mcp_provider.timeout = timeout
  228. if sse_read_timeout is not None:
  229. mcp_provider.sse_read_timeout = sse_read_timeout
  230. if headers is not None:
  231. # Encrypt headers
  232. if headers:
  233. encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
  234. mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
  235. else:
  236. mcp_provider.encrypted_headers = None
  237. db.session.commit()
  238. except IntegrityError as e:
  239. db.session.rollback()
  240. error_msg = str(e.orig)
  241. if "unique_mcp_provider_name" in error_msg:
  242. raise ValueError(f"MCP tool {name} already exists")
  243. if "unique_mcp_provider_server_url" in error_msg:
  244. raise ValueError(f"MCP tool {server_url} already exists")
  245. if "unique_mcp_provider_server_identifier" in error_msg:
  246. raise ValueError(f"MCP tool {server_identifier} already exists")
  247. raise
  248. except Exception:
  249. db.session.rollback()
  250. raise
  251. @classmethod
  252. def update_mcp_provider_credentials(
  253. cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
  254. ):
  255. provider_controller = MCPToolProviderController.from_db(mcp_provider)
  256. tool_configuration = ProviderConfigEncrypter(
  257. tenant_id=mcp_provider.tenant_id,
  258. config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
  259. provider_config_cache=NoOpProviderCredentialCache(),
  260. )
  261. credentials = tool_configuration.encrypt(credentials)
  262. mcp_provider.updated_at = datetime.now()
  263. mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
  264. mcp_provider.authed = authed
  265. if not authed:
  266. mcp_provider.tools = "[]"
  267. db.session.commit()
  268. @classmethod
  269. def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
  270. # Get the existing provider to access headers and timeout settings
  271. mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
  272. headers = mcp_provider.decrypted_headers
  273. timeout = mcp_provider.timeout
  274. sse_read_timeout = mcp_provider.sse_read_timeout
  275. try:
  276. with MCPClient(
  277. server_url,
  278. provider_id,
  279. tenant_id,
  280. authed=False,
  281. for_list=True,
  282. headers=headers,
  283. timeout=timeout,
  284. sse_read_timeout=sse_read_timeout,
  285. ) as mcp_client:
  286. tools = mcp_client.list_tools()
  287. return {
  288. "authed": True,
  289. "tools": json.dumps([tool.model_dump() for tool in tools]),
  290. "encrypted_credentials": "{}",
  291. }
  292. except MCPAuthError:
  293. return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
  294. except MCPError as e:
  295. raise ValueError(f"Failed to re-connect MCP server: {e}") from e