You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import asyncio
  2. from concurrent.futures import ThreadPoolExecutor
  3. import logging
  4. from string import Template
  5. from typing import Any, Literal
  6. from typing_extensions import override
  7. from mcp.client.session import ClientSession
  8. from mcp.client.sse import sse_client
  9. from mcp.client.streamable_http import streamablehttp_client
  10. from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
  11. from api.db import MCPServerType
  12. from rag.llm.chat_model import ToolCallSession
  13. MCPTaskType = Literal["list_tools", "tool_call", "stop"]
  14. MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
  15. class MCPToolCallSession(ToolCallSession):
  16. _EVENT_LOOP = asyncio.new_event_loop()
  17. _THREAD_POOL = ThreadPoolExecutor(max_workers=1)
  18. _mcp_server: Any
  19. _server_variables: dict[str, Any]
  20. _queue: asyncio.Queue[MCPTask]
  21. _stop = False
  22. @classmethod
  23. def _init_thread_pool(cls) -> None:
  24. cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever)
  25. def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
  26. self._mcp_server = mcp_server
  27. self._server_variables = server_variables or {}
  28. self._queue = asyncio.Queue()
  29. asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP)
  30. async def _mcp_server_loop(self) -> None:
  31. url = self._mcp_server.url
  32. raw_headers: dict[str, str] = self._mcp_server.headers or {}
  33. headers: dict[str, str] = {}
  34. for h, v in raw_headers.items():
  35. nh = Template(h).safe_substitute(self._server_variables)
  36. nv = Template(v).safe_substitute(self._server_variables)
  37. headers[nh] = nv
  38. _streams_source: Any
  39. if self._mcp_server.server_type == MCPServerType.SSE:
  40. _streams_source = sse_client(url, headers)
  41. elif self._mcp_server.server_type == MCPServerType.StreamableHttp:
  42. _streams_source = streamablehttp_client(url, headers)
  43. else:
  44. raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
  45. async with _streams_source as streams:
  46. async with ClientSession(*streams) as client_session:
  47. await client_session.initialize()
  48. while not self._stop:
  49. mcp_task, arguments, result_queue = await self._queue.get()
  50. logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
  51. r: Any
  52. try:
  53. if mcp_task == "list_tools":
  54. r = await client_session.list_tools()
  55. elif mcp_task == "tool_call":
  56. r = await client_session.call_tool(**arguments)
  57. elif mcp_task == "stop":
  58. logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}")
  59. self._stop = True
  60. continue
  61. else:
  62. r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}")
  63. except Exception as e:
  64. r = e
  65. await result_queue.put(r)
  66. async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any:
  67. results = asyncio.Queue()
  68. await self._queue.put((task_type, kwargs, results))
  69. result: CallToolResult | Exception = await results.get()
  70. if isinstance(result, Exception):
  71. raise result
  72. return result
  73. async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
  74. result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)
  75. if result.isError:
  76. return f"MCP server error: {result.content}"
  77. # For now we only support text content
  78. if isinstance(result.content[0], TextContent):
  79. return result.content[0].text
  80. else:
  81. return f"Unsupported content type {type(result.content)}"
  82. async def _get_tools_from_mcp_server(self) -> list[Tool]:
  83. # For now we only fetch the first page of tools
  84. result: ListToolsResult = await self._call_mcp_server("list_tools")
  85. return result.tools
  86. def get_tools(self) -> list[Tool]:
  87. return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result()
  88. @override
  89. def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
  90. return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result()
  91. async def close(self) -> None:
  92. await self._call_mcp_server("stop")
  93. def close_sync(self) -> None:
  94. asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result()
  95. MCPToolCallSession._init_thread_pool()
  96. def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
  97. async def _gather() -> None:
  98. await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True)
  99. asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result()
  100. def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
  101. return {
  102. "type": "function",
  103. "function": {
  104. "name": mcp_tool.name,
  105. "description": mcp_tool.description,
  106. "parameters": mcp_tool.inputSchema,
  107. },
  108. }