|
|
|
@@ -1,45 +1,43 @@ |
|
|
|
import asyncio |
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import logging |
|
|
|
import threading |
|
|
|
import weakref |
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
from string import Template |
|
|
|
from typing import Any, Literal |
|
|
|
|
|
|
|
from typing_extensions import override |
|
|
|
|
|
|
|
from api.db import MCPServerType |
|
|
|
from mcp.client.session import ClientSession |
|
|
|
from mcp.client.sse import sse_client |
|
|
|
from mcp.client.streamable_http import streamablehttp_client |
|
|
|
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool |
|
|
|
|
|
|
|
from api.db import MCPServerType |
|
|
|
from rag.llm.chat_model import ToolCallSession |
|
|
|
|
|
|
|
|
|
|
|
MCPTaskType = Literal["list_tools", "tool_call", "stop"] |
|
|
|
MCPTaskType = Literal["list_tools", "tool_call"] |
|
|
|
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] |
|
|
|
|
|
|
|
|
|
|
|
class MCPToolCallSession(ToolCallSession): |
|
|
|
_EVENT_LOOP = asyncio.new_event_loop() |
|
|
|
_THREAD_POOL = ThreadPoolExecutor(max_workers=1) |
|
|
|
|
|
|
|
_mcp_server: Any |
|
|
|
_server_variables: dict[str, Any] |
|
|
|
_queue: asyncio.Queue[MCPTask] |
|
|
|
_stop = False |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def _init_thread_pool(cls) -> None: |
|
|
|
cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever) |
|
|
|
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() |
|
|
|
|
|
|
|
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None: |
|
|
|
self.__class__._ALL_INSTANCES.add(self) |
|
|
|
|
|
|
|
self._mcp_server = mcp_server |
|
|
|
self._server_variables = server_variables or {} |
|
|
|
self._queue = asyncio.Queue() |
|
|
|
self._close = False |
|
|
|
|
|
|
|
self._event_loop = asyncio.new_event_loop() |
|
|
|
self._thread_pool = ThreadPoolExecutor(max_workers=1) |
|
|
|
self._thread_pool.submit(self._event_loop.run_forever) |
|
|
|
|
|
|
|
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP) |
|
|
|
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop) |
|
|
|
|
|
|
|
async def _mcp_server_loop(self) -> None: |
|
|
|
url = self._mcp_server.url |
|
|
|
url = self._mcp_server.url.strip() |
|
|
|
raw_headers: dict[str, str] = self._mcp_server.headers or {} |
|
|
|
headers: dict[str, str] = {} |
|
|
|
|
|
|
|
@@ -48,45 +46,62 @@ class MCPToolCallSession(ToolCallSession): |
|
|
|
nv = Template(v).safe_substitute(self._server_variables) |
|
|
|
headers[nh] = nv |
|
|
|
|
|
|
|
_streams_source: Any |
|
|
|
|
|
|
|
if self._mcp_server.server_type == MCPServerType.SSE: |
|
|
|
_streams_source = sse_client(url, headers) |
|
|
|
elif self._mcp_server.server_type == MCPServerType.StreamableHttp: |
|
|
|
_streams_source = streamablehttp_client(url, headers) |
|
|
|
# SSE transport |
|
|
|
async with sse_client(url, headers) as stream: |
|
|
|
async with ClientSession(*stream) as client_session: |
|
|
|
try: |
|
|
|
await asyncio.wait_for(client_session.initialize(), timeout=5) |
|
|
|
logging.info("client_session initialized successfully") |
|
|
|
except asyncio.TimeoutError: |
|
|
|
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") |
|
|
|
return |
|
|
|
await self._process_mcp_tasks(client_session) |
|
|
|
|
|
|
|
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: |
|
|
|
# Streamable HTTP transport |
|
|
|
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _): |
|
|
|
async with ClientSession(read_stream, write_stream) as client_session: |
|
|
|
try: |
|
|
|
await asyncio.wait_for(client_session.initialize(), timeout=5) |
|
|
|
logging.info("client_session initialized successfully") |
|
|
|
except asyncio.TimeoutError: |
|
|
|
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") |
|
|
|
return |
|
|
|
await asyncio.wait_for(client_session.initialize(), timeout=5) |
|
|
|
await self._process_mcp_tasks(client_session) |
|
|
|
else: |
|
|
|
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}") |
|
|
|
|
|
|
|
async with _streams_source as streams: |
|
|
|
async with ClientSession(*streams) as client_session: |
|
|
|
await client_session.initialize() |
|
|
|
|
|
|
|
while not self._stop: |
|
|
|
mcp_task, arguments, result_queue = await self._queue.get() |
|
|
|
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") |
|
|
|
|
|
|
|
r: Any |
|
|
|
|
|
|
|
try: |
|
|
|
if mcp_task == "list_tools": |
|
|
|
r = await client_session.list_tools() |
|
|
|
elif mcp_task == "tool_call": |
|
|
|
r = await client_session.call_tool(**arguments) |
|
|
|
elif mcp_task == "stop": |
|
|
|
logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}") |
|
|
|
self._stop = True |
|
|
|
continue |
|
|
|
else: |
|
|
|
r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}") |
|
|
|
except Exception as e: |
|
|
|
r = e |
|
|
|
|
|
|
|
await result_queue.put(r) |
|
|
|
|
|
|
|
async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any: |
|
|
|
async def _process_mcp_tasks(self, client_session: ClientSession) -> None: |
|
|
|
while not self._close: |
|
|
|
try: |
|
|
|
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1) |
|
|
|
except asyncio.TimeoutError: |
|
|
|
continue |
|
|
|
|
|
|
|
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") |
|
|
|
|
|
|
|
r: Any = None |
|
|
|
try: |
|
|
|
if mcp_task == "list_tools": |
|
|
|
r = await client_session.list_tools() |
|
|
|
elif mcp_task == "tool_call": |
|
|
|
r = await client_session.call_tool(**arguments) |
|
|
|
else: |
|
|
|
r = ValueError(f"Unknown MCP task {mcp_task}") |
|
|
|
except Exception as e: |
|
|
|
r = e |
|
|
|
|
|
|
|
await result_queue.put(r) |
|
|
|
|
|
|
|
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any: |
|
|
|
results = asyncio.Queue() |
|
|
|
await self._queue.put((task_type, kwargs, results)) |
|
|
|
result: CallToolResult | Exception = await results.get() |
|
|
|
try: |
|
|
|
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout) |
|
|
|
except asyncio.TimeoutError: |
|
|
|
raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") |
|
|
|
|
|
|
|
if isinstance(result, Exception): |
|
|
|
raise result |
|
|
|
@@ -106,32 +121,84 @@ class MCPToolCallSession(ToolCallSession): |
|
|
|
return f"Unsupported content type {type(result.content)}" |
|
|
|
|
|
|
|
async def _get_tools_from_mcp_server(self) -> list[Tool]: |
|
|
|
# For now we only fetch the first page of tools |
|
|
|
result: ListToolsResult = await self._call_mcp_server("list_tools") |
|
|
|
return result.tools |
|
|
|
|
|
|
|
def get_tools(self) -> list[Tool]: |
|
|
|
return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result() |
|
|
|
def get_tools(self, timeout: float = 10) -> list[Tool]: |
|
|
|
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop) |
|
|
|
try: |
|
|
|
return future.result(timeout=timeout) |
|
|
|
except TimeoutError: |
|
|
|
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}") |
|
|
|
return [] |
|
|
|
except Exception: |
|
|
|
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") |
|
|
|
return [] |
|
|
|
|
|
|
|
@override |
|
|
|
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: |
|
|
|
return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result() |
|
|
|
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str: |
|
|
|
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) |
|
|
|
try: |
|
|
|
return future.result(timeout=timeout) |
|
|
|
except TimeoutError as te: |
|
|
|
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}") |
|
|
|
return f"Timeout calling tool '{name}': {te}." |
|
|
|
except Exception as e: |
|
|
|
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}") |
|
|
|
return f"Error calling tool '{name}': {e}." |
|
|
|
|
|
|
|
async def close(self) -> None: |
|
|
|
await self._call_mcp_server("stop") |
|
|
|
if self._close: |
|
|
|
return |
|
|
|
|
|
|
|
def close_sync(self) -> None: |
|
|
|
asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result() |
|
|
|
self._close = True |
|
|
|
self._event_loop.call_soon_threadsafe(self._event_loop.stop) |
|
|
|
self._thread_pool.shutdown(wait=True) |
|
|
|
self.__class__._ALL_INSTANCES.discard(self) |
|
|
|
|
|
|
|
def close_sync(self, timeout: float = 5) -> None: |
|
|
|
if not self._event_loop.is_running(): |
|
|
|
logging.warning(f"Event loop already stopped for {self._mcp_server.id}") |
|
|
|
return |
|
|
|
|
|
|
|
MCPToolCallSession._init_thread_pool() |
|
|
|
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) |
|
|
|
try: |
|
|
|
future.result(timeout=timeout) |
|
|
|
except TimeoutError: |
|
|
|
logging.error(f"Timeout while closing session for server {self._mcp_server.id}") |
|
|
|
except Exception: |
|
|
|
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") |
|
|
|
|
|
|
|
|
|
|
|
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: |
|
|
|
async def _gather() -> None: |
|
|
|
await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True) |
|
|
|
logging.info(f"Want to clean up {len(sessions)} MCP sessions") |
|
|
|
|
|
|
|
async def _gather_and_stop() -> None: |
|
|
|
try: |
|
|
|
await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True) |
|
|
|
finally: |
|
|
|
loop.call_soon_threadsafe(loop.stop) |
|
|
|
|
|
|
|
loop = asyncio.new_event_loop() |
|
|
|
thread = threading.Thread(target=loop.run_forever, daemon=True) |
|
|
|
thread.start() |
|
|
|
|
|
|
|
asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result() |
|
|
|
|
|
|
|
thread.join() |
|
|
|
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") |
|
|
|
|
|
|
|
|
|
|
|
def shutdown_all_mcp_sessions(): |
|
|
|
"""Gracefully shutdown all active MCPToolCallSession instances.""" |
|
|
|
sessions = list(MCPToolCallSession._ALL_INSTANCES) |
|
|
|
if not sessions: |
|
|
|
logging.info("No MCPToolCallSession instances to close.") |
|
|
|
return |
|
|
|
|
|
|
|
asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result() |
|
|
|
logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...") |
|
|
|
close_multiple_mcp_toolcall_sessions(sessions) |
|
|
|
logging.info("All MCPToolCallSession instances have been closed.") |
|
|
|
|
|
|
|
|
|
|
|
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]: |