瀏覽代碼

Feat: add MCP dashboard functionalities list_tools and test_tool (#8505)

### What problem does this PR solve?

Add MCP dashboard functionalities list_tools and test_tool.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.20.0
Yongteng Lei 4 月之前
父節點
當前提交
0eb90e73a5
No account linked to committer's email address
共有 4 個文件被更改,包括 228 次插入87 次删除
  1. 75
    3
      api/apps/mcp_server_app.py
  2. 2
    0
      api/ragflow_server.py
  3. 21
    21
      api/utils/web_utils.py
  4. 130
    63
      mcp_client/mcp_tool_call.py

+ 75
- 3
api/apps/mcp_server_app.py 查看文件

from api.settings import RetCode from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.web_utils import safe_json_parse
from api.utils.web_utils import get_float, safe_json_parse
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions




@manager.route("/list", methods=["POST"]) # noqa: F821 @manager.route("/list", methods=["POST"]) # noqa: F821
if server_name and len(server_name.encode("utf-8")) > 255: if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.") return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")


req["headers"] = safe_json_parse(req.get("headers", {}))
req["variables"] = safe_json_parse(req.get("variables", {}))
mcp_id = req.get("id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")

req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers))
req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables))


try: try:
req["tenant_id"] = current_user.id req["tenant_id"] = current_user.id
return get_json_result(data={"mcpServers": exported_servers}) return get_json_result(data={"mcpServers": exported_servers})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)


@manager.route("/list_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
req = request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")

timeout = get_float(req, "timeout", 10)

results = {}
tool_call_sessions = []
try:
for mcp_id in mcp_ids:
e, mcp_server = MCPServerService.get_by_id(mcp_id)

if e and mcp_server.tenant_id == current_user.id:
server_key = mcp_server.id

tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
tools = tool_call_session.get_tools(timeout)

results[server_key] = [tool.model_dump() for tool in tools]

# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return get_json_result(data=results)
except Exception as e:
return server_error_response(e)


@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
req = request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")

timeout = get_float(req, "timeout", 10)

tool_name = req.get("tool_name", "")
arguments = req.get("arguments", {})
if not all([tool_name, arguments]):
return get_data_error_result(message="Require provide tool name and arguments.")

tool_call_sessions = []
try:
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")

tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
result = tool_call_session.tool_call(tool_name, arguments, timeout)

# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)

+ 2
- 0
api/ragflow_server.py 查看文件

# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code


from api.utils.log_utils import init_root_logger from api.utils.log_utils import init_root_logger
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
from plugin import GlobalPluginManager from plugin import GlobalPluginManager
init_root_logger("ragflow_server") init_root_logger("ragflow_server")




def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions()
stop_event.set() stop_event.set()
time.sleep(1) time.sleep(1)
sys.exit(0) sys.exit(0)

+ 21
- 21
api/utils/web_utils.py 查看文件

# limitations under the License. # limitations under the License.
# #


import base64
import ipaddress
import json
import re import re
import socket import socket
from urllib.parse import urlparse from urllib.parse import urlparse
import ipaddress
import json
import base64


from selenium import webdriver from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service from selenium.webdriver.chrome.service import Service
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By
from selenium.webdriver.support.expected_conditions import staleness_of from selenium.webdriver.support.expected_conditions import staleness_of
from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.common.by import By




def html2pdf( def html2pdf(
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
): ):
result = __get_pdf_from_html(source, timeout, install_driver, print_options) result = __get_pdf_from_html(source, timeout, install_driver, print_options)
return result return result
return response.get("value") return response.get("value")




def __get_pdf_from_html(
path: str,
timeout: int,
install_driver: bool,
print_options: dict
):
def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict):
webdriver_options = Options() webdriver_options = Options()
webdriver_prefs = {} webdriver_prefs = {}
webdriver_options.add_argument("--headless") webdriver_options.add_argument("--headless")
driver.get(path) driver.get(path)


try: try:
WebDriverWait(driver, timeout).until(
staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))
)
WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html")))
except TimeoutException: except TimeoutException:
calculated_print_options = { calculated_print_options = {
"landscape": False, "landscape": False,
"preferCSSPageSize": True, "preferCSSPageSize": True,
} }
calculated_print_options.update(print_options) calculated_print_options.update(print_options)
result = __send_devtools(
driver, "Page.printToPDF", calculated_print_options)
result = __send_devtools(driver, "Page.printToPDF", calculated_print_options)
driver.quit() driver.quit()
return base64.b64decode(result["data"]) return base64.b64decode(result["data"])


except ValueError: except ValueError:
return False return False



def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url): if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
return False return False
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}



def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
try:
parsed = float(req.get(key, default))
return parsed if parsed > 0 else default
except (TypeError, ValueError):
return default

+ 130
- 63
mcp_client/mcp_tool_call.py 查看文件

import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging import logging
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from string import Template from string import Template
from typing import Any, Literal from typing import Any, Literal

from typing_extensions import override from typing_extensions import override


from api.db import MCPServerType
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool

from api.db import MCPServerType
from rag.llm.chat_model import ToolCallSession from rag.llm.chat_model import ToolCallSession



MCPTaskType = Literal["list_tools", "tool_call", "stop"]
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]




class MCPToolCallSession(ToolCallSession): class MCPToolCallSession(ToolCallSession):
_EVENT_LOOP = asyncio.new_event_loop()
_THREAD_POOL = ThreadPoolExecutor(max_workers=1)

_mcp_server: Any
_server_variables: dict[str, Any]
_queue: asyncio.Queue[MCPTask]
_stop = False

@classmethod
def _init_thread_pool(cls) -> None:
cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever)
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()


def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None: def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
self.__class__._ALL_INSTANCES.add(self)

self._mcp_server = mcp_server self._mcp_server = mcp_server
self._server_variables = server_variables or {} self._server_variables = server_variables or {}
self._queue = asyncio.Queue() self._queue = asyncio.Queue()
self._close = False

self._event_loop = asyncio.new_event_loop()
self._thread_pool = ThreadPoolExecutor(max_workers=1)
self._thread_pool.submit(self._event_loop.run_forever)


asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP)
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)


async def _mcp_server_loop(self) -> None: async def _mcp_server_loop(self) -> None:
url = self._mcp_server.url
url = self._mcp_server.url.strip()
raw_headers: dict[str, str] = self._mcp_server.headers or {} raw_headers: dict[str, str] = self._mcp_server.headers or {}
headers: dict[str, str] = {} headers: dict[str, str] = {}


nv = Template(v).safe_substitute(self._server_variables) nv = Template(v).safe_substitute(self._server_variables)
headers[nh] = nv headers[nh] = nv


_streams_source: Any

if self._mcp_server.server_type == MCPServerType.SSE: if self._mcp_server.server_type == MCPServerType.SSE:
_streams_source = sse_client(url, headers)
elif self._mcp_server.server_type == MCPServerType.StreamableHttp:
_streams_source = streamablehttp_client(url, headers)
# SSE transport
async with sse_client(url, headers) as stream:
async with ClientSession(*stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
except asyncio.TimeoutError:
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
return
await self._process_mcp_tasks(client_session)

elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
# Streamable HTTP transport
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
except asyncio.TimeoutError:
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
return
await asyncio.wait_for(client_session.initialize(), timeout=5)
await self._process_mcp_tasks(client_session)
else: else:
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}") raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")


async with _streams_source as streams:
async with ClientSession(*streams) as client_session:
await client_session.initialize()

while not self._stop:
mcp_task, arguments, result_queue = await self._queue.get()
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")

r: Any

try:
if mcp_task == "list_tools":
r = await client_session.list_tools()
elif mcp_task == "tool_call":
r = await client_session.call_tool(**arguments)
elif mcp_task == "stop":
logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}")
self._stop = True
continue
else:
r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}")
except Exception as e:
r = e

await result_queue.put(r)

async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any:
async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
while not self._close:
try:
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
except asyncio.TimeoutError:
continue

logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")

r: Any = None
try:
if mcp_task == "list_tools":
r = await client_session.list_tools()
elif mcp_task == "tool_call":
r = await client_session.call_tool(**arguments)
else:
r = ValueError(f"Unknown MCP task {mcp_task}")
except Exception as e:
r = e

await result_queue.put(r)

async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
results = asyncio.Queue() results = asyncio.Queue()
await self._queue.put((task_type, kwargs, results)) await self._queue.put((task_type, kwargs, results))
result: CallToolResult | Exception = await results.get()
try:
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")


if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
return f"Unsupported content type {type(result.content)}" return f"Unsupported content type {type(result.content)}"


async def _get_tools_from_mcp_server(self) -> list[Tool]: async def _get_tools_from_mcp_server(self) -> list[Tool]:
# For now we only fetch the first page of tools
result: ListToolsResult = await self._call_mcp_server("list_tools") result: ListToolsResult = await self._call_mcp_server("list_tools")
return result.tools return result.tools


def get_tools(self) -> list[Tool]:
return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result()
def get_tools(self, timeout: float = 10) -> list[Tool]:
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
try:
return future.result(timeout=timeout)
except TimeoutError:
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}")
return []
except Exception:
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
return []


@override @override
def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result()
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
try:
return future.result(timeout=timeout)
except TimeoutError as te:
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Timeout calling tool '{name}': {te}."
except Exception as e:
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Error calling tool '{name}': {e}."


async def close(self) -> None: async def close(self) -> None:
await self._call_mcp_server("stop")
if self._close:
return


def close_sync(self) -> None:
asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result()
self._close = True
self._event_loop.call_soon_threadsafe(self._event_loop.stop)
self._thread_pool.shutdown(wait=True)
self.__class__._ALL_INSTANCES.discard(self)


def close_sync(self, timeout: float = 5) -> None:
if not self._event_loop.is_running():
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
return


MCPToolCallSession._init_thread_pool()
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
try:
future.result(timeout=timeout)
except TimeoutError:
logging.error(f"Timeout while closing session for server {self._mcp_server.id}")
except Exception:
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")




def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
async def _gather() -> None:
await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True)
logging.info(f"Want to clean up {len(sessions)} MCP sessions")

async def _gather_and_stop() -> None:
try:
await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
finally:
loop.call_soon_threadsafe(loop.stop)

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()

thread.join()
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")


def shutdown_all_mcp_sessions():
"""Gracefully shutdown all active MCPToolCallSession instances."""
sessions = list(MCPToolCallSession._ALL_INSTANCES)
if not sessions:
logging.info("No MCPToolCallSession instances to close.")
return


asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result()
logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...")
close_multiple_mcp_toolcall_sessions(sessions)
logging.info("All MCPToolCallSession instances have been closed.")




def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]: def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:

Loading…
取消
儲存