Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder mit einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

mcp_tool_call_conn.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import asyncio
  17. import logging
  18. import threading
  19. import weakref
  20. from concurrent.futures import ThreadPoolExecutor
  21. from concurrent.futures import TimeoutError as FuturesTimeoutError
  22. from string import Template
  23. from typing import Any, Literal
  24. from typing_extensions import override
  25. from api.db import MCPServerType
  26. from mcp.client.session import ClientSession
  27. from mcp.client.sse import sse_client
  28. from mcp.client.streamable_http import streamablehttp_client
  29. from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
  30. from rag.llm.chat_model import ToolCallSession
  31. MCPTaskType = Literal["list_tools", "tool_call"]
  32. MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
  33. class MCPToolCallSession(ToolCallSession):
  34. _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
  35. def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
  36. self.__class__._ALL_INSTANCES.add(self)
  37. self._mcp_server = mcp_server
  38. self._server_variables = server_variables or {}
  39. self._queue = asyncio.Queue()
  40. self._close = False
  41. self._event_loop = asyncio.new_event_loop()
  42. self._thread_pool = ThreadPoolExecutor(max_workers=1)
  43. self._thread_pool.submit(self._event_loop.run_forever)
  44. asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)
  45. async def _mcp_server_loop(self) -> None:
  46. url = self._mcp_server.url.strip()
  47. raw_headers: dict[str, str] = self._mcp_server.headers or {}
  48. headers: dict[str, str] = {}
  49. for h, v in raw_headers.items():
  50. nh = Template(h).safe_substitute(self._server_variables)
  51. nv = Template(v).safe_substitute(self._server_variables)
  52. headers[nh] = nv
  53. if self._mcp_server.server_type == MCPServerType.SSE:
  54. # SSE transport
  55. try:
  56. async with sse_client(url, headers) as stream:
  57. async with ClientSession(*stream) as client_session:
  58. try:
  59. await asyncio.wait_for(client_session.initialize(), timeout=5)
  60. logging.info("client_session initialized successfully")
  61. await self._process_mcp_tasks(client_session)
  62. except asyncio.TimeoutError:
  63. msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
  64. logging.error(msg)
  65. await self._process_mcp_tasks(None, msg)
  66. except Exception:
  67. msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
  68. await self._process_mcp_tasks(None, msg)
  69. elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
  70. # Streamable HTTP transport
  71. try:
  72. async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
  73. async with ClientSession(read_stream, write_stream) as client_session:
  74. try:
  75. await asyncio.wait_for(client_session.initialize(), timeout=5)
  76. logging.info("client_session initialized successfully")
  77. await self._process_mcp_tasks(client_session)
  78. except asyncio.TimeoutError:
  79. msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
  80. logging.error(msg)
  81. await self._process_mcp_tasks(None, msg)
  82. except Exception:
  83. msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
  84. await self._process_mcp_tasks(None, msg)
  85. else:
  86. await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
  87. async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
  88. while not self._close:
  89. try:
  90. mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
  91. except asyncio.TimeoutError:
  92. continue
  93. logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
  94. r: Any = None
  95. if not client_session or error_message:
  96. r = ValueError(error_message)
  97. await result_queue.put(r)
  98. continue
  99. try:
  100. if mcp_task == "list_tools":
  101. r = await client_session.list_tools()
  102. elif mcp_task == "tool_call":
  103. r = await client_session.call_tool(**arguments)
  104. else:
  105. r = ValueError(f"Unknown MCP task {mcp_task}")
  106. except Exception as e:
  107. r = e
  108. await result_queue.put(r)
  109. async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any:
  110. results = asyncio.Queue()
  111. await self._queue.put((task_type, kwargs, results))
  112. try:
  113. result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
  114. if isinstance(result, Exception):
  115. raise result
  116. return result
  117. except asyncio.TimeoutError:
  118. raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
  119. except Exception:
  120. raise
  121. async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
  122. result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
  123. if result.isError:
  124. return f"MCP server error: {result.content}"
  125. # For now we only support text content
  126. if isinstance(result.content[0], TextContent):
  127. return result.content[0].text
  128. else:
  129. return f"Unsupported content type {type(result.content)}"
  130. async def _get_tools_from_mcp_server(self, timeout: float | int = 8) -> list[Tool]:
  131. try:
  132. result: ListToolsResult = await self._call_mcp_server("list_tools", timeout=timeout)
  133. return result.tools
  134. except Exception:
  135. raise
  136. def get_tools(self, timeout: float | int = 10) -> list[Tool]:
  137. future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop)
  138. try:
  139. return future.result(timeout=timeout)
  140. except FuturesTimeoutError:
  141. msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})"
  142. logging.error(msg)
  143. raise RuntimeError(msg)
  144. except Exception:
  145. logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
  146. raise
  147. @override
  148. def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
  149. future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
  150. try:
  151. return future.result(timeout=timeout)
  152. except FuturesTimeoutError:
  153. logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
  154. return f"Timeout calling tool '{name}' (timeout={timeout})."
  155. except Exception as e:
  156. logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
  157. return f"Error calling tool '{name}': {e}."
  158. async def close(self) -> None:
  159. if self._close:
  160. return
  161. self._close = True
  162. self._event_loop.call_soon_threadsafe(self._event_loop.stop)
  163. self._thread_pool.shutdown(wait=True)
  164. self.__class__._ALL_INSTANCES.discard(self)
  165. def close_sync(self, timeout: float | int = 5) -> None:
  166. if not self._event_loop.is_running():
  167. logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
  168. return
  169. future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
  170. try:
  171. future.result(timeout=timeout)
  172. except FuturesTimeoutError:
  173. logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
  174. except Exception:
  175. logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
  176. def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
  177. logging.info(f"Want to clean up {len(sessions)} MCP sessions")
  178. async def _gather_and_stop() -> None:
  179. try:
  180. await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
  181. finally:
  182. loop.call_soon_threadsafe(loop.stop)
  183. loop = asyncio.new_event_loop()
  184. thread = threading.Thread(target=loop.run_forever, daemon=True)
  185. thread.start()
  186. asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
  187. thread.join()
  188. logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
  189. def shutdown_all_mcp_sessions():
  190. """Gracefully shutdown all active MCPToolCallSession instances."""
  191. sessions = list(MCPToolCallSession._ALL_INSTANCES)
  192. if not sessions:
  193. logging.info("No MCPToolCallSession instances to close.")
  194. return
  195. logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...")
  196. close_multiple_mcp_toolcall_sessions(sessions)
  197. logging.info("All MCPToolCallSession instances have been closed.")
  198. def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
  199. return {
  200. "type": "function",
  201. "function": {
  202. "name": mcp_tool.name,
  203. "description": mcp_tool.description,
  204. "parameters": mcp_tool.inputSchema,
  205. },
  206. }