You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mcp_tool_call_conn.py 9.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import asyncio
  17. import logging
  18. import threading
  19. import weakref
  20. from concurrent.futures import ThreadPoolExecutor
  21. from concurrent.futures import TimeoutError as FuturesTimeoutError
  22. from string import Template
  23. from typing import Any, Literal
  24. from typing_extensions import override
  25. from api.db import MCPServerType
  26. from mcp.client.session import ClientSession
  27. from mcp.client.sse import sse_client
  28. from mcp.client.streamable_http import streamablehttp_client
  29. from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
  30. from rag.llm.chat_model import ToolCallSession
  31. MCPTaskType = Literal["list_tools", "tool_call"]
  32. MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
  33. class MCPToolCallSession(ToolCallSession):
  34. _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
  35. def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
  36. self.__class__._ALL_INSTANCES.add(self)
  37. self._mcp_server = mcp_server
  38. self._server_variables = server_variables or {}
  39. self._queue = asyncio.Queue()
  40. self._close = False
  41. self._event_loop = asyncio.new_event_loop()
  42. self._thread_pool = ThreadPoolExecutor(max_workers=1)
  43. self._thread_pool.submit(self._event_loop.run_forever)
  44. asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)
  45. async def _mcp_server_loop(self) -> None:
  46. url = self._mcp_server.url.strip()
  47. raw_headers: dict[str, str] = self._mcp_server.headers or {}
  48. headers: dict[str, str] = {}
  49. for h, v in raw_headers.items():
  50. nh = Template(h).safe_substitute(self._server_variables)
  51. nv = Template(v).safe_substitute(self._server_variables)
  52. headers[nh] = nv
  53. if self._mcp_server.server_type == MCPServerType.SSE:
  54. # SSE transport
  55. async with sse_client(url, headers) as stream:
  56. async with ClientSession(*stream) as client_session:
  57. try:
  58. await asyncio.wait_for(client_session.initialize(), timeout=5)
  59. logging.info("client_session initialized successfully")
  60. except asyncio.TimeoutError:
  61. logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
  62. return
  63. await self._process_mcp_tasks(client_session)
  64. elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
  65. # Streamable HTTP transport
  66. async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
  67. async with ClientSession(read_stream, write_stream) as client_session:
  68. try:
  69. await asyncio.wait_for(client_session.initialize(), timeout=5)
  70. logging.info("client_session initialized successfully")
  71. except asyncio.TimeoutError:
  72. logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
  73. return
  74. await asyncio.wait_for(client_session.initialize(), timeout=5)
  75. await self._process_mcp_tasks(client_session)
  76. else:
  77. raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
  78. async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
  79. while not self._close:
  80. try:
  81. mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
  82. except asyncio.TimeoutError:
  83. continue
  84. logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
  85. r: Any = None
  86. try:
  87. if mcp_task == "list_tools":
  88. r = await client_session.list_tools()
  89. elif mcp_task == "tool_call":
  90. r = await client_session.call_tool(**arguments)
  91. else:
  92. r = ValueError(f"Unknown MCP task {mcp_task}")
  93. except Exception as e:
  94. r = e
  95. await result_queue.put(r)
  96. async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
  97. results = asyncio.Queue()
  98. await self._queue.put((task_type, kwargs, results))
  99. try:
  100. result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
  101. except asyncio.TimeoutError:
  102. raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
  103. if isinstance(result, Exception):
  104. raise result
  105. return result
  106. async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
  107. result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)
  108. if result.isError:
  109. return f"MCP server error: {result.content}"
  110. # For now we only support text content
  111. if isinstance(result.content[0], TextContent):
  112. return result.content[0].text
  113. else:
  114. return f"Unsupported content type {type(result.content)}"
  115. async def _get_tools_from_mcp_server(self) -> list[Tool]:
  116. result: ListToolsResult = await self._call_mcp_server("list_tools")
  117. return result.tools
  118. def get_tools(self, timeout: float = 10) -> list[Tool]:
  119. future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
  120. try:
  121. return future.result(timeout=timeout)
  122. except FuturesTimeoutError:
  123. logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})")
  124. return []
  125. except Exception:
  126. logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
  127. return []
  128. @override
  129. def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
  130. future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
  131. try:
  132. return future.result(timeout=timeout)
  133. except FuturesTimeoutError:
  134. logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
  135. return f"Timeout calling tool '{name}' (timeout={timeout})."
  136. except Exception as e:
  137. logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
  138. return f"Error calling tool '{name}': {e}."
  139. async def close(self) -> None:
  140. if self._close:
  141. return
  142. self._close = True
  143. self._event_loop.call_soon_threadsafe(self._event_loop.stop)
  144. self._thread_pool.shutdown(wait=True)
  145. self.__class__._ALL_INSTANCES.discard(self)
  146. def close_sync(self, timeout: float = 5) -> None:
  147. if not self._event_loop.is_running():
  148. logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
  149. return
  150. future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
  151. try:
  152. future.result(timeout=timeout)
  153. except FuturesTimeoutError:
  154. logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
  155. except Exception:
  156. logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
  157. def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
  158. logging.info(f"Want to clean up {len(sessions)} MCP sessions")
  159. async def _gather_and_stop() -> None:
  160. try:
  161. await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
  162. finally:
  163. loop.call_soon_threadsafe(loop.stop)
  164. loop = asyncio.new_event_loop()
  165. thread = threading.Thread(target=loop.run_forever, daemon=True)
  166. thread.start()
  167. asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
  168. thread.join()
  169. logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
  170. def shutdown_all_mcp_sessions():
  171. """Gracefully shutdown all active MCPToolCallSession instances."""
  172. sessions = list(MCPToolCallSession._ALL_INSTANCES)
  173. if not sessions:
  174. logging.info("No MCPToolCallSession instances to close.")
  175. return
  176. logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...")
  177. close_multiple_mcp_toolcall_sessions(sessions)
  178. logging.info("All MCPToolCallSession instances have been closed.")
  179. def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
  180. return {
  181. "type": "function",
  182. "function": {
  183. "name": mcp_tool.name,
  184. "description": mcp_tool.description,
  185. "parameters": mcp_tool.inputSchema,
  186. },
  187. }