Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

mcp_tool_call.py 8.7KB

(line-number gutter: lines 1–212)
  1. import asyncio
  2. import logging
  3. import threading
  4. import weakref
  5. from concurrent.futures import ThreadPoolExecutor
  6. from string import Template
  7. from typing import Any, Literal
  8. from typing_extensions import override
  9. from api.db import MCPServerType
  10. from mcp.client.session import ClientSession
  11. from mcp.client.sse import sse_client
  12. from mcp.client.streamable_http import streamablehttp_client
  13. from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
  14. from rag.llm.chat_model import ToolCallSession
  # Request kinds handled by MCPToolCallSession._process_mcp_tasks.
  15. MCPTaskType = Literal["list_tools", "tool_call"]
  # One queued request: (request kind, keyword arguments forwarded to the MCP client
  # call, queue on which the background loop puts the result or the raised exception).
  16. MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
  class MCPToolCallSession(ToolCallSession):
      """Synchronous facade over an MCP client session running on a private event loop.

      Each instance owns a dedicated asyncio event loop driven by a single-worker
      ThreadPoolExecutor. A long-lived coroutine (``_mcp_server_loop``) connects to the
      MCP server over SSE or streamable HTTP, initializes a ``ClientSession`` and then
      serves requests taken from ``self._queue``. The synchronous methods
      (``get_tools``, ``tool_call``, ``close_sync``) hand work to that loop and block
      on the result with a timeout.

      Live instances are tracked in the weak ``_ALL_INSTANCES`` registry so that
      module-level helpers can close every open session.
      """

      _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()

      # Seconds allowed for ClientSession.initialize() before giving up on the server.
      _INIT_TIMEOUT: float = 5
      # Seconds close() waits for the server loop to exit cleanly before cancelling it.
      _CLOSE_TIMEOUT: float = 5

      def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
          """Start the background loop and begin connecting to *mcp_server*.

          Args:
              mcp_server: Server record; this class reads its ``url``, ``headers``,
                  ``server_type`` and ``id`` attributes.
              server_variables: Values substituted into ``$name`` placeholders in
                  header names and values (``string.Template.safe_substitute``).
          """
          self._mcp_server = mcp_server
          self._server_variables = server_variables or {}
          self._queue = asyncio.Queue()
          self._close = False

          self._event_loop = asyncio.new_event_loop()
          self._thread_pool = ThreadPoolExecutor(max_workers=1)
          self._thread_pool.submit(self._event_loop.run_forever)
          # Keep the handle: shutdown waits on it for a clean exit, and requests fail
          # fast once it is done instead of waiting out their timeout.
          self._server_future = asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)
          self._server_future.add_done_callback(self._on_server_loop_done)

          # Register last so a half-constructed instance is never visible to shutdown helpers.
          self.__class__._ALL_INSTANCES.add(self)

      def _on_server_loop_done(self, future: Any) -> None:
          """Log the error that ended the background server loop, if any."""
          if future.cancelled():
              return
          error = future.exception()
          if error is not None:
              logging.error(
                  f"MCP server loop for server {self._mcp_server.id} exited with error: {error!r}",
                  exc_info=error,
              )

      def _on_own_loop(self) -> bool:
          """Return True when called from code running on this session's event loop."""
          try:
              return asyncio.get_running_loop() is self._event_loop
          except RuntimeError:
              return False

      def _render_headers(self) -> dict[str, str]:
          """Return the server headers with ``$var`` placeholders filled from server_variables."""
          raw_headers: dict[str, str] = self._mcp_server.headers or {}
          variables = self._server_variables
          return {
              Template(name).safe_substitute(variables): Template(value).safe_substitute(variables)
              for name, value in raw_headers.items()
          }

      async def _mcp_server_loop(self) -> None:
          """Connect with the configured transport and serve queued requests until closed.

          Raises:
              ValueError: If the server type is neither SSE nor streamable HTTP.
          """
          url = self._mcp_server.url.strip()
          headers = self._render_headers()

          if self._mcp_server.server_type == MCPServerType.SSE:
              # SSE transport
              async with sse_client(url, headers) as streams:
                  await self._run_client_session(*streams)
          elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
              # Streamable HTTP transport; the third yielded value is not used here.
              async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
                  await self._run_client_session(read_stream, write_stream)
          else:
              raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")

      async def _run_client_session(self, read_stream: Any, write_stream: Any) -> None:
          """Initialize one ClientSession on the given streams, then serve requests on it.

          Returns without serving if initialization does not complete within
          ``_INIT_TIMEOUT`` seconds.
          """
          async with ClientSession(read_stream, write_stream) as client_session:
              try:
                  # Initialize exactly once; a second initialize() round-trip is redundant.
                  await asyncio.wait_for(client_session.initialize(), timeout=self._INIT_TIMEOUT)
                  logging.info("client_session initialized successfully")
              except asyncio.TimeoutError:
                  logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
                  return
              await self._process_mcp_tasks(client_session)

      async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
          """Serve queued requests on *client_session* until ``self._close`` is set.

          The queue is polled with a 1 s timeout so the close flag is noticed promptly.
          An exception raised by a request is delivered to its caller through the
          request's result queue instead of stopping this loop.
          """
          while not self._close:
              try:
                  mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
              except asyncio.TimeoutError:
                  continue
              logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
              r: Any = None
              try:
                  if mcp_task == "list_tools":
                      r = await client_session.list_tools()
                  elif mcp_task == "tool_call":
                      r = await client_session.call_tool(**arguments)
                  else:
                      r = ValueError(f"Unknown MCP task {mcp_task}")
              except Exception as e:
                  r = e
              await result_queue.put(r)

      async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
          """Queue a request for the server loop and await its result (runs on the session loop).

          Raises:
              RuntimeError: If the session is closing or the server loop has already exited
                  (connection failure, initialization timeout, or close).
              TimeoutError: If no result arrives within *timeout* seconds.
              Exception: Whatever the MCP request itself raised.
          """
          if self._close or self._server_future.done():
              # Nobody will consume the queue; fail now instead of waiting out the timeout.
              raise RuntimeError(f"MCP server loop for server {self._mcp_server.id} is not running")
          results: asyncio.Queue[Any] = asyncio.Queue()
          await self._queue.put((task_type, kwargs, results))
          try:
              result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
          except asyncio.TimeoutError:
              raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
          if isinstance(result, Exception):
              raise result
          return result

      async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
          """Call tool *name* and return its first content item as text.

          An error result, an empty result, or a non-text first item is reported
          as a string rather than raised.
          """
          result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)
          if result.isError:
              return f"MCP server error: {result.content}"
          if not result.content:
              # A successful call that produced no output.
              return ""
          first_item = result.content[0]
          # For now we only support text content
          if isinstance(first_item, TextContent):
              return first_item.text
          return f"Unsupported content type {type(first_item)}"

      async def _get_tools_from_mcp_server(self) -> list[Tool]:
          """Return the tool list reported by the server's list_tools call."""
          result: ListToolsResult = await self._call_mcp_server("list_tools")
          return result.tools

      def _run_on_loop(self, coro: Any, timeout: float) -> Any:
          """Run *coro* on the session loop and block up to *timeout* seconds for its result.

          Raises:
              RuntimeError: If the session is closed (the coroutine is discarded unscheduled).
              TimeoutError: If the result is not ready in time; the pending work is cancelled.
          """
          from concurrent.futures import TimeoutError as FutureTimeoutError

          if self._close or self._event_loop.is_closed():
              coro.close()  # never scheduled; close it to avoid a "never awaited" warning
              raise RuntimeError(f"MCP session for server {self._mcp_server.id} is closed")
          future = asyncio.run_coroutine_threadsafe(coro, self._event_loop)
          try:
              return future.result(timeout=timeout)
          except FutureTimeoutError:
              if future.done():
                  # Python 3.11+: this is the builtin TimeoutError raised by the coroutine itself.
                  raise
              # Before 3.11 concurrent.futures.TimeoutError is not the builtin TimeoutError,
              # so normalize it; cancel so the abandoned work does not linger on the loop.
              future.cancel()
              raise TimeoutError(f"Timed out after {timeout}s waiting for MCP server {self._mcp_server.id}") from None

      def get_tools(self, timeout: float = 10) -> list[Tool]:
          """Return the server's tools, or ``[]`` (after logging) on timeout or any error."""
          try:
              return self._run_on_loop(self._get_tools_from_mcp_server(), timeout)
          except TimeoutError:
              logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}")
              return []
          except Exception:
              logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
              return []

      @override
      def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
          """Call tool *name* with *arguments*; failures are returned as descriptive strings."""
          try:
              return self._run_on_loop(self._call_mcp_tool(name, arguments), timeout)
          except TimeoutError as te:
              logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}")
              return f"Timeout calling tool '{name}': {te}."
          except Exception as e:
              logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
              return f"Error calling tool '{name}': {e}."

      def _shutdown(self, timeout: float) -> None:
          """Blocking teardown; must not run on the session's own loop thread.

          Waits up to *timeout* seconds for the server loop to leave its transport and
          ClientSession context managers (it re-checks ``self._close`` about once a
          second), cancels it if it does not, then stops the event loop, joins the
          worker thread, closes the loop and unregisters the instance.
          """
          self._close = True
          try:
              self._server_future.result(timeout=timeout)
          except Exception:
              if not self._server_future.done():
                  logging.error(f"Timeout while closing session for server {self._mcp_server.id}")
                  self._server_future.cancel()
              # Otherwise the loop had already failed and _on_server_loop_done logged why.
          if not self._event_loop.is_closed():
              self._event_loop.call_soon_threadsafe(self._event_loop.stop)
          self._thread_pool.shutdown(wait=True)
          if not self._event_loop.is_closed():
              self._event_loop.close()
          self.__class__._ALL_INSTANCES.discard(self)

      async def close(self) -> None:
          """Close the session; calling it again is a no-op.

          From another event loop, the blocking teardown runs in a worker thread so
          the caller's loop stays responsive. On this session's own loop, where
          blocking would deadlock, teardown is handed to a background thread and
          this coroutine returns immediately.
          """
          if self._close:
              return
          self._close = True
          if self._on_own_loop():
              threading.Thread(target=self._shutdown, args=(self._CLOSE_TIMEOUT,), daemon=True).start()
              return
          await asyncio.to_thread(self._shutdown, self._CLOSE_TIMEOUT)

      def close_sync(self, timeout: float = 5) -> None:
          """Blocking close; waits up to *timeout* seconds for a clean disconnect."""
          if self._close or self._event_loop.is_closed():
              logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
              return
          self._close = True
          if self._on_own_loop():
              # Blocking here would deadlock the loop thread; tear down from another thread.
              threading.Thread(target=self._shutdown, args=(timeout,), daemon=True).start()
              return
          try:
              self._shutdown(timeout)
          except Exception:
              logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
  def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
      """Close *sessions* concurrently and block until every close() has finished.

      The close() coroutines are gathered on a private event loop created by
      ``asyncio.run`` in a dedicated thread, so this also works when the caller is
      itself running inside an event loop. ``None`` entries are skipped. A failure to
      close one session is logged and does not prevent the others from closing.
      """
      logging.info(f"Want to clean up {len(sessions)} MCP sessions")
      live_sessions = [s for s in sessions if s is not None]

      async def _close_all() -> list[Any]:
          return await asyncio.gather(*[s.close() for s in live_sessions], return_exceptions=True)

      outcomes: list[Any] = []

      def _run() -> None:
          # asyncio.run closes its loop (and default executor) when done; a bare
          # new_event_loop() that is only stopped would leak the loop and its selector.
          outcomes.extend(asyncio.run(_close_all()))

      if live_sessions:
          worker = threading.Thread(target=_run, daemon=True)
          worker.start()
          worker.join()

      for session, outcome in zip(live_sessions, outcomes):
          if isinstance(outcome, BaseException):
              logging.error(f"Failed to close MCP session for server {session._mcp_server.id}: {outcome!r}")
      logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
  def shutdown_all_mcp_sessions():
      """Close every MCPToolCallSession still present in the class-wide weak registry.

      Logs a short message and returns early when no session is registered.
      """
      # Take a snapshot: closing a session removes it from the weak registry.
      pending = [*MCPToolCallSession._ALL_INSTANCES]
      if pending:
          logging.info(f"Shutting down {len(pending)} MCPToolCallSession instances...")
          close_multiple_mcp_toolcall_sessions(pending)
          logging.info("All MCPToolCallSession instances have been closed.")
      else:
          logging.info("No MCPToolCallSession instances to close.")
  def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
      """Convert MCP tool metadata into an OpenAI-style function tool description.

      The tool's ``name``, ``description`` and ``inputSchema`` are copied as-is into
      the ``function`` entry (``inputSchema`` becomes ``parameters``).
      """
      function_spec = dict(
          name=mcp_tool.name,
          description=mcp_tool.description,
          parameters=mcp_tool.inputSchema,
      )
      return {"type": "function", "function": function_spec}