### What problem does this PR solve?

This is a cherry-pick from #7781 as requested.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

tags/v0.20.0
| @@ -200,6 +200,7 @@ COPY graphrag graphrag | |||
| COPY agentic_reasoning agentic_reasoning | |||
| COPY pyproject.toml uv.lock ./ | |||
| COPY mcp mcp | |||
| COPY mcp_client mcp_client | |||
| COPY plugin plugin | |||
| COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template | |||
| @@ -33,6 +33,7 @@ ADD ./rag ./rag | |||
| ADD ./requirements.txt ./requirements.txt | |||
| ADD ./agent ./agent | |||
| ADD ./graphrag ./graphrag | |||
| ADD ./mcp_client ./mcp_client | |||
| ADD ./plugin ./plugin | |||
| RUN dnf install -y openmpi openmpi-devel python3-openmpi | |||
| @@ -0,0 +1,107 @@ | |||
| from flask import Response, request | |||
| from flask_login import current_user, login_required | |||
| from api.db.db_models import MCPServer | |||
| from api.db.services.mcp_server_service import MCPServerService | |||
| from api.db.services.user_service import TenantService | |||
| from api.settings import RetCode | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request | |||
@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
def get_list() -> Response:
    """List the MCP servers that belong to the logged-in user.

    Returns:
        Response: JSON result whose ``data`` is the list produced by
        ``MCPServerService.get_servers``; an empty list is returned when the
        service yields nothing. Any exception is converted into a server
        error response.
    """
    try:
        servers = MCPServerService.get_servers(current_user.id)
        # The service reports "nothing found" as a falsy value; the API always answers with a list.
        payload = servers or []
        return get_json_result(data=payload)
    except Exception as e:
        return server_error_response(e)
@manager.route("/get_multiple", methods=["POST"])  # noqa: F821
@login_required
@validate_request("id_list")
def get_multiple() -> Response:
    """Fetch several MCP servers of the logged-in user by their IDs.

    The request body must contain ``id_list``.

    Returns:
        Response: JSON result whose ``data`` is the list of matching servers,
        or an empty list when the service yields nothing. Any exception is
        converted into a server error response.
    """
    body = request.json
    try:
        wanted_ids = body["id_list"]
        servers = MCPServerService.get_servers(current_user.id, id_list=wanted_ids)
        return get_json_result(data=servers or [])
    except Exception as e:
        return server_error_response(e)
@manager.route("/get/<ms_id>", methods=["GET"])  # noqa: F821
@login_required
def get(ms_id: str) -> Response:
    """Return a single MCP server owned by the logged-in user.

    Args:
        ms_id (str): ID of the MCP server, taken from the URL.

    Returns:
        Response: JSON result with the server as a dict, or a result with
        ``RetCode.NOT_FOUND`` and ``data=None`` when no server with this ID
        exists for the current user. Any exception is converted into a server
        error response.
    """
    try:
        # Filtering on tenant_id as well keeps other users' servers invisible.
        found = MCPServerService.get_or_none(id=ms_id, tenant_id=current_user.id)
        if found is not None:
            return get_json_result(data=found.to_dict())
        return get_json_result(code=RetCode.NOT_FOUND, data=None)
    except Exception as e:
        return server_error_response(e)
@manager.route("/create", methods=["POST"])  # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
def create() -> Response:
    """Register a new MCP server for the logged-in user.

    The request body must contain ``name``, ``url`` and ``server_type``.
    ``server_type`` has to be one of the ``MCPServerType`` values; anything
    else is rejected up front instead of being stored and failing only when
    a client later tries to connect. A missing or empty ``headers`` entry is
    stored as an empty dict.

    Returns:
        Response: JSON result with the generated ``id`` on success, a data
        error result when the server type is unsupported, the tenant cannot
        be found or the insert fails, or a server error response when an
        exception is raised.
    """
    # Imported locally so the handler does not depend on the module's import block.
    from api.db import MCPServerType

    req = request.json
    try:
        # List membership compares with ==, so unhashable JSON values are handled too.
        if req["server_type"] not in [t.value for t in MCPServerType]:
            return get_data_error_result(message="Unsupported MCP server type.")

        tenant_found, _ = TenantService.get_by_id(current_user.id)
        if not tenant_found:
            return get_data_error_result(message="Tenant not found.")

        # ID and owner are always assigned by the server, overriding anything sent by the client.
        req["id"] = get_uuid()
        req["tenant_id"] = current_user.id
        if not req.get("headers"):
            req["headers"] = {}

        if not MCPServerService.insert(**req):
            return get_data_error_result()

        return get_json_result(data={"id": req["id"]})
    except Exception as e:
        return server_error_response(e)
@manager.route("/update", methods=["POST"])  # noqa: F821
@login_required
@validate_request("id", "name", "url", "server_type")
def update() -> Response:
    """Update an MCP server owned by the logged-in user.

    The request body must contain ``id``, ``name``, ``url`` and
    ``server_type``. ``server_type`` has to be one of the ``MCPServerType``
    values; anything else is rejected instead of being stored. A missing or
    empty ``headers`` entry is stored as an empty dict.

    Returns:
        Response: JSON result with the server ``id`` on success, a data error
        result when the server type is unsupported or the update reports no
        affected row, or a server error response when an exception is raised.
    """
    # Imported locally so the handler does not depend on the module's import block.
    from api.db import MCPServerType

    req = request.json
    if not req.get("headers"):
        req["headers"] = {}

    try:
        # List membership compares with ==, so unhashable JSON values are handled too.
        if req["server_type"] not in [t.value for t in MCPServerType]:
            return get_data_error_result(message="Unsupported MCP server type.")

        # The owner always comes from the session, never from the request body.
        req["tenant_id"] = current_user.id

        # Matching on tenant_id as well prevents updating another user's server.
        owned_by_user = [MCPServer.id == req["id"], MCPServer.tenant_id == req["tenant_id"]]
        if not MCPServerService.filter_update(owned_by_user, req):
            return get_data_error_result()

        return get_json_result(data={"id": req["id"]})
    except Exception as e:
        return server_error_response(e)
@manager.route("/rm", methods=["POST"])  # noqa: F821
@login_required
@validate_request("id")
def rm() -> Response:
    """Delete an MCP server owned by the logged-in user.

    The request body must contain ``id``.

    Returns:
        Response: JSON result with the deleted ``id`` on success, a data
        error result when the delete reports nothing removed, or a server
        error response when an exception is raised.
    """
    payload = request.json
    server_id = payload["id"]

    try:
        payload["tenant_id"] = current_user.id
        # Matching on tenant_id as well prevents deleting another user's server.
        owned_by_user = [MCPServer.id == server_id, MCPServer.tenant_id == payload["tenant_id"]]
        removed = MCPServerService.filter_delete(owned_by_user)
        if removed:
            return get_json_result(data={"id": payload["id"]})
        return get_data_error_result()
    except Exception as e:
        return server_error_response(e)
| @@ -104,4 +104,9 @@ class CanvasType(StrEnum): | |||
| ChatBot = "chatbot" | |||
| DocBot = "docbot" | |||
class MCPServerType(StrEnum):
    """String identifiers for the kinds of MCP server that can be registered."""
    # Stored as the plain strings "sse" / "streamable-http".
    SSE = "sse"
    StreamableHttp = "streamable-http"
| KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" | |||
| @@ -799,6 +799,20 @@ class UserCanvasVersion(DataBaseModel): | |||
| db_table = "user_canvas_version" | |||
class MCPServer(DataBaseModel):
    """Database model for a registered MCP server (table ``mcp_server``)."""

    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=255, null=False, help_text="MCP Server name")
    tenant_id = CharField(max_length=32, null=False, index=True)
    url = CharField(max_length=2048, null=False, help_text="MCP Server URL")
    server_type = CharField(max_length=32, null=False, help_text="MCP Server type")
    description = TextField(null=True, help_text="MCP Server description")
    # Callable defaults: every new instance gets its own list/dict instead of all
    # instances sharing (and possibly mutating) one module-level object.
    variables = JSONField(null=True, default=list, help_text="MCP Server variables")
    headers = JSONField(null=True, default=dict, help_text="MCP Server additional request headers")

    class Meta:
        db_table = "mcp_server"
| class Search(DataBaseModel): | |||
| id = CharField(max_length=32, primary_key=True) | |||
| avatar = TextField(null=True, help_text="avatar base64 string") | |||
| @@ -934,3 +948,7 @@ def migrate_db(): | |||
| migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))) | |||
| except Exception: | |||
| pass | |||
| try: | |||
| migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=[]))) | |||
| except Exception: | |||
| pass | |||
| @@ -0,0 +1,61 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from api.db.db_models import DB, MCPServer | |||
| from api.db.services.common_service import CommonService | |||
class MCPServerService(CommonService):
    """Database service for MCP server records.

    Builds on ``CommonService`` and adds a tenant-scoped listing query.

    Attributes:
        model: The ``MCPServer`` model class the service operates on.
    """

    model = MCPServer

    @classmethod
    @DB.connection_context()
    def get_servers(cls, tenant_id: str, id_list: list[str] | None = None):
        """Return the MCP servers of one tenant, newest first.

        Only the columns needed for list display are selected: id, name,
        server type, URL, description, variables and update date.

        Args:
            tenant_id (str): ID of the tenant whose servers are listed.
            id_list (list[str] | None): When not ``None``, restrict the result
                to servers whose ID is in this list.

        Returns:
            list[dict] | None: One dict per server, ordered by creation time
            descending, or ``None`` when no server matches.
        """
        m = cls.model
        query = m.select(
            m.id,
            m.name,
            m.server_type,
            m.url,
            m.description,
            m.variables,
            m.update_date,
        )
        query = query.order_by(m.create_time.desc()).where(m.tenant_id == tenant_id)

        if id_list is not None:
            query = query.where(m.id.in_(id_list))

        rows = list(query.dicts())
        # An empty result is reported as None rather than as an empty list.
        return rows or None
| @@ -0,0 +1,23 @@ | |||
| from mcp.server import FastMCP | |||
# Sample MCP server named "simple-tools", configured with port 8080.
app = FastMCP("simple-tools", port=8080)
| @app.tool() | |||
async def bad_calculator(a: int, b: int) -> str:
    """
    A calculator to sum up two numbers (will give wrong answer)
    Args:
    a: The first number
    b: The second number
    Returns:
    Sum of a and b
    """
    # Deliberately off by 200, so it is obvious when an answer came from this tool.
    true_sum = a + b
    skewed = true_sum + 200
    return str(skewed)
# When executed as a script, serve the tools using the "sse" transport.
if __name__ == "__main__":
    app.run(transport="sse")
| @@ -0,0 +1,2 @@ | |||
| # ruff: noqa: F401 | |||
| from .mcp_tool_call import MCPToolCallSession, mcp_tool_metadata_to_openai_tool, close_multiple_mcp_toolcall_sessions | |||
| @@ -0,0 +1,145 @@ | |||
| import asyncio | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import logging | |||
| from string import Template | |||
| from typing import Any, Literal | |||
| from typing_extensions import override | |||
| from mcp.client.session import ClientSession | |||
| from mcp.client.sse import sse_client | |||
| from mcp.client.streamable_http import streamablehttp_client | |||
| from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool | |||
| from api.db import MCPServerType | |||
| from rag.llm.chat_model import ToolCallSession | |||
# Kinds of request that can be sent to a session's server loop.
MCPTaskType = Literal["list_tools", "tool_call", "stop"]
# One queued request: (task kind, keyword arguments for the task, queue on which the result is returned).
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class MCPToolCallSession(ToolCallSession):
    """Tool-call session backed by a single MCP server connection.

    Every session schedules one long-running coroutine (``_mcp_server_loop``)
    on a class-wide event loop. The coroutine opens the connection, then
    serves requests that the public methods push onto ``_queue``; each request
    carries its own result queue on which the answer (or an exception) is
    handed back.

    The synchronous methods (``get_tools``, ``tool_call``, ``close_sync``)
    submit work to the shared event loop and block until it completes, so
    they must not be called from that loop's own thread.
    """

    # Event loop shared by all sessions; _init_thread_pool() runs it on the single pool worker.
    _EVENT_LOOP = asyncio.new_event_loop()
    _THREAD_POOL = ThreadPoolExecutor(max_workers=1)

    _mcp_server: Any
    _server_variables: dict[str, Any]
    _queue: asyncio.Queue[MCPTask]
    _stop = False
    # Set once the server loop has finished; later requests fail with this instead of waiting forever.
    _exit_reason: Exception | None = None

    @classmethod
    def _init_thread_pool(cls) -> None:
        """Start running the shared event loop on the thread pool."""
        cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever)

    def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
        """Create the session and schedule its server loop on the shared event loop.

        Args:
            mcp_server: Server record; its ``url``, ``headers``, ``server_type`` and ``id`` are read.
            server_variables: Values substituted into ``$placeholder`` references in the
                configured header names and values. Defaults to no variables.
        """
        self._mcp_server = mcp_server
        self._server_variables = server_variables or {}
        self._queue = asyncio.Queue()

        asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP)

    async def _mcp_server_loop(self) -> None:
        """Run the session until it is stopped or fails, then release every waiting caller.

        Whatever ends the loop, requests still sitting in the queue are answered
        and ``_exit_reason`` is recorded, so no caller is left blocked on a
        result that will never arrive.
        """
        failure: Exception | None = None
        try:
            await self._run_session()
        except Exception as e:
            logging.exception("MCPToolCallSession server loop terminated unexpectedly")
            failure = e
        finally:
            self._stop = True
            if failure is None:
                failure = RuntimeError("MCPToolCallSession is closed")
            # No await between these two steps: requests either see _exit_reason or get drained here.
            self._exit_reason = failure
            self._reject_pending_tasks(failure)

    async def _run_session(self) -> None:
        """Connect to the MCP server and serve queued tasks until a "stop" task arrives.

        Raises:
            ValueError: If the server type is neither SSE nor streamable HTTP.
        """
        url = self._mcp_server.url
        raw_headers: dict[str, str] = self._mcp_server.headers or {}
        headers: dict[str, str] = {}

        # Both header names and values may reference server variables.
        for h, v in raw_headers.items():
            nh = Template(h).safe_substitute(self._server_variables)
            nv = Template(v).safe_substitute(self._server_variables)
            headers[nh] = nv

        _streams_source: Any

        if self._mcp_server.server_type == MCPServerType.SSE:
            _streams_source = sse_client(url, headers)
        elif self._mcp_server.server_type == MCPServerType.StreamableHttp:
            _streams_source = streamablehttp_client(url, headers)
        else:
            raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")

        async with _streams_source as streams:
            async with ClientSession(*streams) as client_session:
                await client_session.initialize()

                while not self._stop:
                    mcp_task, arguments, result_queue = await self._queue.get()
                    logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")

                    if mcp_task == "stop":
                        logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}")
                        self._stop = True
                        # Acknowledge the request; without a reply close() would wait forever.
                        await result_queue.put(None)
                        continue

                    r: Any

                    try:
                        if mcp_task == "list_tools":
                            r = await client_session.list_tools()
                        elif mcp_task == "tool_call":
                            r = await client_session.call_tool(**arguments)
                        else:
                            r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}")
                    except Exception as e:
                        # Errors are handed to the caller; the loop keeps serving.
                        r = e

                    await result_queue.put(r)

    def _reject_pending_tasks(self, reason: Exception) -> None:
        """Answer every task still queued: "stop" is acknowledged, anything else receives ``reason``."""
        while not self._queue.empty():
            task_type, _, result_queue = self._queue.get_nowait()
            result_queue.put_nowait(None if task_type == "stop" else reason)

    async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any:
        """Send one task to the server loop and wait for its result.

        Args:
            task_type: Kind of task to run.
            **kwargs: Arguments passed along with the task.

        Returns:
            The task's result; ``None`` for "stop".

        Raises:
            Exception: Whatever the task produced as an error, or the reason the
                server loop ended if it is no longer running.
        """
        if self._exit_reason is not None:
            # The loop is gone; queueing would block forever. Closing again is a no-op.
            if task_type == "stop":
                return None
            raise self._exit_reason

        results = asyncio.Queue()
        await self._queue.put((task_type, kwargs, results))
        result: CallToolResult | Exception = await results.get()

        if isinstance(result, Exception):
            raise result

        return result

    async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
        """Call a tool on the server and return its result as text.

        Returns:
            The text of the first content item, an empty string when the tool
            returned no content, or a descriptive message when the server
            reported an error or the first item is not text.
        """
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)

        if result.isError:
            return f"MCP server error: {result.content}"

        # A tool may legitimately return nothing; indexing [0] would raise IndexError.
        if not result.content:
            return ""

        # For now we only support text content
        first_item = result.content[0]
        if isinstance(first_item, TextContent):
            return first_item.text
        else:
            return f"Unsupported content type {type(first_item)}"

    async def _get_tools_from_mcp_server(self) -> list[Tool]:
        """Return the tools listed by the server."""
        # For now we only fetch the first page of tools
        result: ListToolsResult = await self._call_mcp_server("list_tools")
        return result.tools

    def get_tools(self) -> list[Tool]:
        """Blocking variant of ``_get_tools_from_mcp_server``; runs it on the shared event loop."""
        return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result()

    @override
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
        """Blocking variant of ``_call_mcp_tool``; runs it on the shared event loop."""
        return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result()

    async def close(self) -> None:
        """Ask the server loop to stop and wait for the request to be acknowledged."""
        await self._call_mcp_server("stop")

    def close_sync(self) -> None:
        """Blocking variant of ``close``; runs it on the shared event loop."""
        asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result()
# Start the shared event loop as soon as this module is imported.
MCPToolCallSession._init_thread_pool()
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
    """Close all given sessions concurrently and block until every close has finished.

    Exceptions raised by individual ``close()`` calls are collected by
    ``asyncio.gather`` and not re-raised.

    Args:
        sessions: Sessions to close.
    """

    async def _close_all() -> None:
        pending = [session.close() for session in sessions]
        await asyncio.gather(*pending, return_exceptions=True)

    # The closes must run on the loop that owns the sessions.
    future = asyncio.run_coroutine_threadsafe(_close_all(), MCPToolCallSession._EVENT_LOOP)
    future.result()
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
    """Convert MCP tool metadata into a function-tool description dict.

    Args:
        mcp_tool: Tool metadata; its ``name``, ``description`` and ``inputSchema`` are read.

    Returns:
        dict[str, Any]: ``{"type": "function", "function": {...}}`` where the inner dict
        carries the tool's name, its description and its input schema as ``parameters``.
    """
    function_spec: dict[str, Any] = {}
    function_spec["name"] = mcp_tool.name
    function_spec["description"] = mcp_tool.description
    function_spec["parameters"] = mcp_tool.inputSchema

    return {"type": "function", "function": function_spec}
| @@ -61,6 +61,9 @@ class ToolCallSession(Protocol): | |||
| class Base(ABC): | |||
| tools: list[Any] | |||
| toolcall_sessions: dict[str, ToolCallSession] | |||
| def __init__(self, key, model_name, base_url, **kwargs): | |||
| timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) | |||
| self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) | |||
| @@ -70,6 +73,8 @@ class Base(ABC): | |||
| self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) | |||
| self.max_rounds = kwargs.get("max_rounds", 5) | |||
| self.is_tools = False | |||
| self.tools = [] | |||
| self.toolcall_sessions = {} | |||
| def _get_delay(self): | |||
| """Calculate retry delay time""" | |||
| @@ -145,8 +150,10 @@ class Base(ABC): | |||
| if not (toolcall_session and tools): | |||
| return | |||
| self.is_tools = True | |||
| self.toolcall_session = toolcall_session | |||
| self.tools = tools | |||
| for tool in tools: | |||
| self.toolcall_sessions[tool["function"]["name"]] = toolcall_session | |||
| self.tools.append(tool) | |||
| def chat_with_tools(self, system: str, history: list, gen_conf: dict): | |||
| gen_conf = self._clean_conf() | |||
| @@ -180,7 +187,7 @@ class Base(ABC): | |||
| name = tool_call.function.name | |||
| try: | |||
| args = json_repair.loads(tool_call.function.arguments) | |||
| tool_response = self.toolcall_session.tool_call(name, args) | |||
| tool_response = self.toolcall_sessions[name].tool_call(name, args) | |||
| history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) | |||
| except Exception as e: | |||
| history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) | |||
| @@ -286,7 +293,7 @@ class Base(ABC): | |||
| name = tool_call.function.name | |||
| try: | |||
| args = json_repair.loads(tool_call.function.arguments) | |||
| tool_response = self.toolcall_session.tool_call(name, args) | |||
| tool_response = self.toolcall_sessions[name].tool_call(name, args) | |||
| history.append( | |||
| { | |||
| "role": "assistant", | |||
| @@ -585,7 +592,7 @@ class QWenChat(Base): | |||
| tool_name = assistant_output.tool_calls[0]["function"]["name"] | |||
| if tool_name: | |||
| arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"]) | |||
| tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments) | |||
| tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments) | |||
| history.append(tool_info) | |||
| response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf) | |||
| @@ -708,7 +715,7 @@ class QWenChat(Base): | |||
| tool_name = toolcall_message.tool_calls[0]["function"]["name"] | |||
| history.append(toolcall_message) | |||
| tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments) | |||
| tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments) | |||
| history.append(tool_info) | |||
| tool_info = {"content": "", "role": "tool"} | |||
| tool_name = "" | |||
| @@ -1,5 +1,5 @@ | |||
| version = 1 | |||
| revision = 1 | |||
| revision = 2 | |||
| requires-python = ">=3.10, <3.13" | |||
| resolution-markers = [ | |||
| "python_full_version >= '3.12' and sys_platform == 'darwin'", | |||
| @@ -0,0 +1,19 @@ | |||
/**
 * Kinds of MCP server, identified by the string values `'sse'` and
 * `'streamable-http'`.
 */
export enum McpServerType {
  Sse = 'sse',
  StreamableHttp = 'streamable-http',
}
/**
 * One entry of an MCP server's `variables` list.
 *
 * NOTE(review): `key` is presumably the placeholder identifier and `name` its
 * display label — confirm against the backend.
 */
export interface IMcpServerVariable {
  key: string;
  name: string;
}
/**
 * An MCP server record as exchanged with the backend.
 */
export interface IMcpServerInfo {
  id: string;
  name: string;
  url: string;
  server_type: McpServerType;
  description?: string;
  variables?: IMcpServerVariable[];
  /**
   * Additional request headers, keyed by header name.
   *
   * Typed as a plain object rather than a `Map`: a parsed JSON response never
   * contains `Map` instances, and `JSON.stringify` serialises a `Map` as `{}`,
   * which would silently drop every header when saving.
   */
  headers: Record<string, string>;
}
| @@ -0,0 +1,41 @@ | |||
| import api from '@/utils/api'; | |||
| import registerServer from '@/utils/register-server'; | |||
| import request from '@/utils/request'; | |||
/**
 * Endpoint table for the MCP server API: each entry maps a service method
 * name to the URL and HTTP method used to call it.
 */
const methods = {
  get_list: { url: api.getMcpServerList, method: 'get' },
  get_multiple: { url: api.getMultipleMcpServers, method: 'post' },
  add: { url: api.createMcpServer, method: 'post' },
  update: { url: api.updateMcpServer, method: 'post' },
  rm: { url: api.deleteMcpServer, method: 'post' },
} as const;

type McpServerMethodName = keyof typeof methods;

/** Service object built by `registerServer` from the endpoint table. */
const mcpServerService = registerServer<McpServerMethodName>(methods, request);
/**
 * Fetches a single MCP server by ID.
 *
 * Kept outside the endpoint table because its URL is built from the ID.
 *
 * @param serverId - ID of the MCP server to fetch.
 * @returns Whatever `request.get` returns for the server's URL.
 */
export const getMcpServer = (serverId: string) => {
  const url = api.getMcpServer(serverId);
  return request.get(url);
};
| export default mcpServerService; | |||
| @@ -143,4 +143,12 @@ export default { | |||
| testDbConnect: `${api_host}/canvas/test_db_connect`, | |||
| getInputElements: `${api_host}/canvas/input_elements`, | |||
| debug: `${api_host}/canvas/debug`, | |||
| // mcp server | |||
| getMcpServerList: `${api_host}/mcp_server/list`, | |||
| getMultipleMcpServers: `${api_host}/mcp_server/get_multiple`, | |||
| getMcpServer: (serverId: string) => `${api_host}/mcp_server/get/${serverId}`, | |||
| createMcpServer: `${api_host}/mcp_server/create`, | |||
| updateMcpServer: `${api_host}/mcp_server/update`, | |||
| deleteMcpServer: `${api_host}/mcp_server/rm`, | |||
| }; | |||