| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
- #
- # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
-
- import json
- import logging
- from collections.abc import AsyncIterator
- from contextlib import asynccontextmanager
- from functools import wraps
-
- import click
- import requests
- from starlette.applications import Starlette
- from starlette.middleware import Middleware
- from starlette.responses import JSONResponse, Response
- from starlette.routing import Mount, Route
- from strenum import StrEnum
-
- import mcp.types as types
- from mcp.server.lowlevel import Server
-
-
- class LaunchMode(StrEnum):
- SELF_HOST = "self-host"
- HOST = "host"
-
-
- class Transport(StrEnum):
- SSE = "sse"
- STEAMABLE_HTTP = "streamable-http"
-
-
# Process-wide settings. These module-level values are only defaults: main()
# reassigns every one of them (via `global`) from the CLI options and the
# RAGFLOW_MCP_* environment variables before the server starts.
BASE_URL = "http://127.0.0.1:9380"  # RAGFlow backend the connector talks to
HOST = "127.0.0.1"  # interface the MCP server binds to
PORT = "9382"  # kept as a string; converted with int() when uvicorn starts
HOST_API_KEY = ""  # API key used when not running in host mode
MODE = ""  # launch mode; compared against LaunchMode members

# Transport switches (SSE, streamable HTTP, and JSON responses for streamable HTTP).
TRANSPORT_SSE_ENABLED = TRANSPORT_STREAMABLE_HTTP_ENABLED = JSON_RESPONSE = True
-
-
- class RAGFlowConnector:
- def __init__(self, base_url: str, version="v1"):
- self.base_url = base_url
- self.version = version
- self.api_url = f"{self.base_url}/api/{self.version}"
-
- def bind_api_key(self, api_key: str):
- self.api_key = api_key
- self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
-
- def _post(self, path, json=None, stream=False, files=None):
- if not self.api_key:
- return None
- res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
- return res
-
- def _get(self, path, params=None, json=None):
- res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
- return res
-
- def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
- res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
- if not res:
- raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
-
- res = res.json()
- if res.get("code") == 0:
- result_list = []
- for data in res["data"]:
- d = {"description": data["description"], "id": data["id"]}
- result_list.append(json.dumps(d, ensure_ascii=False))
- return "\n".join(result_list)
- return ""
-
- def retrieval(
- self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
- ):
- if document_ids is None:
- document_ids = []
- data_json = {
- "page": page,
- "page_size": page_size,
- "similarity_threshold": similarity_threshold,
- "vector_similarity_weight": vector_similarity_weight,
- "top_k": top_k,
- "rerank_id": rerank_id,
- "keyword": keyword,
- "question": question,
- "dataset_ids": dataset_ids,
- "document_ids": document_ids,
- }
- # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
- res = self._post("/retrieval", json=data_json)
- if not res:
- raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
-
- res = res.json()
- if res.get("code") == 0:
- chunks = []
- for chunk_data in res["data"].get("chunks"):
- chunks.append(json.dumps(chunk_data, ensure_ascii=False))
- return [types.TextContent(type="text", text="\n".join(chunks))]
- raise Exception([types.TextContent(type="text", text=res.get("message"))])
-
-
class RAGFlowCtx:
    """Lifespan state object that carries the RAGFlow connector to tool handlers.

    sse_lifespan stores an instance under the ``"ragflow_ctx"`` key; handlers
    reach the connector through the ``conn`` attribute.
    """

    def __init__(self, connector: RAGFlowConnector):
        # Exposed under the short name `conn`, which with_api_key reads.
        self.conn = connector
-
-
@asynccontextmanager
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
    """Lifespan hook for the low-level MCP Server.

    Builds a RAGFlowCtx around a connector for the current BASE_URL and yields
    it under the ``"ragflow_ctx"`` key, which with_api_key looks up.

    NOTE(review): despite the log text, this hook is registered as the Server
    lifespan, so it is presumably entered for sessions on either transport.
    """
    connector = RAGFlowConnector(base_url=BASE_URL)
    lifespan_state = {"ragflow_ctx": RAGFlowCtx(connector)}

    logging.info("Legacy SSE application started with StreamableHTTP session manager!")
    try:
        yield lifespan_state
    finally:
        logging.info("Legacy SSE application shutting down...")
-
-
# Low-level MCP server instance. The @app.list_tools() / @app.call_tool()
# handlers below register on it, create_starlette_app() serves it over the
# enabled transports, and sse_lifespan provides the "ragflow_ctx" lifespan
# entry that with_api_key reads.
app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
-
-
def with_api_key(required=True):
    """Decorator factory that injects a RAGFlowConnector with an API key bound.

    The wrapped coroutine receives an extra keyword argument ``connector``,
    taken from the ``"ragflow_ctx"`` entry of the MCP lifespan context.

    * Host mode: the key comes from the request headers captured at session
      start (``Authorization: Bearer <key>`` or an ``api_key`` header).
    * Any other mode: the server-wide ``HOST_API_KEY`` is used.

    Args:
        required: in host mode, raise if no key can be found in the headers.

    Raises:
        ValueError: the lifespan context is missing, or ``required`` is set and
            no key was supplied (host mode).
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Get RAGFlow Context failed")

            connector = ragflow_ctx.conn

            if MODE == LaunchMode.HOST:
                # `experimental` is optional in the MCP capabilities model, and in
                # this file only the SSE handler fills it (with the HTTP headers).
                # Treat a missing value as "no headers" so the clear "key required"
                # error below is raised instead of an AttributeError.
                experimental = ctx.session._init_options.capabilities.experimental or {}
                headers = experimental.get("headers") or {}
                token = None

                # lower case here, because of Starlette conversion
                auth = headers.get("authorization", "")
                if auth.startswith("Bearer "):
                    token = auth.removeprefix("Bearer ").strip()
                elif "api_key" in headers:
                    token = headers["api_key"]

                if required and not token:
                    raise ValueError("RAGFlow API key or Bearer token is required.")

                connector.bind_api_key(token)
            else:
                connector.bind_api_key(HOST_API_KEY)

            return await func(*args, connector=connector, **kwargs)

        return wrapper

    return decorator
-
-
@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
    """Advertise the single ``ragflow_retrieval`` tool.

    The tool description ends with the current dataset catalogue returned by
    ``connector.list_datasets()``, so the client can pick ``dataset_ids``.
    """
    input_schema = {
        "type": "object",
        "properties": {
            "dataset_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched.",
            },
            "document_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional array of document IDs to search within.",
            },
            "question": {"type": "string", "description": "The question or query to search for."},
        },
        "required": ["question"],
    }
    description = (
        "Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. "
        "You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. "
        "You can also optionally specify document_ids to search within specific documents. "
        "When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. "
        "Below is the list of all available datasets, including their descriptions and IDs:"
    ) + connector.list_datasets()

    retrieval_tool = types.Tool(name="ragflow_retrieval", description=description, inputSchema=input_schema)
    return [retrieval_tool]
-
-
def _dataset_ids_from_listing(listing):
    """Collect the ``id`` field from each JSON line of a RAGFlowConnector.list_datasets() result."""
    dataset_ids = []
    # Split on "\n" only: list_datasets joins with "\n", and json.dumps(ensure_ascii=False)
    # can leave characters such as U+2028 unescaped, which str.splitlines() would split on.
    for line in listing.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            dataset_ids.append(json.loads(line)["id"])
        except (json.JSONDecodeError, KeyError, TypeError):
            # Skip malformed lines (including JSON values that are not objects).
            continue
    return dataset_ids


@app.call_tool()
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    """Dispatch an MCP tool call; only ``ragflow_retrieval`` is supported.

    When ``dataset_ids`` is absent or empty, every dataset returned by
    ``connector.list_datasets()`` is searched.

    Raises:
        ValueError: unknown tool name, or the ``question`` argument is missing.
    """
    if name != "ragflow_retrieval":
        raise ValueError(f"Tool not found: {name}")

    # Validate before any network call and report a clear error instead of a bare KeyError.
    question = arguments.get("question")
    if question is None:
        raise ValueError("Missing required argument: question")

    document_ids = arguments.get("document_ids", [])
    dataset_ids = arguments.get("dataset_ids", [])
    if not dataset_ids:
        dataset_ids = _dataset_ids_from_listing(connector.list_datasets())

    return connector.retrieval(
        dataset_ids=dataset_ids,
        document_ids=document_ids,
        question=question,
    )
-
-
def create_starlette_app():
    """Build the Starlette ASGI application that serves the MCP ``app``.

    In host mode, an ASGI middleware answers HTTP 401 for requests under
    ``/messages/``, ``/sse`` or ``/mcp`` that carry neither a ``Bearer`` token
    nor an ``api_key`` header. Only presence is checked here; the key itself is
    not validated. Each enabled transport contributes routes:

    * legacy SSE: ``/sse`` and ``/messages/``;
    * stateless streamable HTTP: mounted at ``/mcp``, with its session manager
      running for the lifetime of the Starlette app.
    """
    from starlette.types import ASGIApp, Receive, Scope, Send

    middleware = None
    if MODE == LaunchMode.HOST:
        guarded_prefixes = ("/messages/", "/sse", "/mcp")

        class AuthMiddleware:
            def __init__(self, app: ASGIApp):
                self.app = app

            @staticmethod
            def _find_token(raw_headers):
                # ASGI header names are lower-case bytes.
                header_map = dict(raw_headers)
                bearer = header_map.get(b"authorization")
                if bearer and bearer.startswith(b"Bearer "):
                    return bearer.removeprefix(b"Bearer ").strip()
                return header_map.get(b"api_key")

            async def __call__(self, scope: Scope, receive: Receive, send: Send):
                guarded = scope["type"] == "http" and scope["path"].startswith(guarded_prefixes)
                if guarded and not self._find_token(scope["headers"]):
                    rejection = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                    await rejection(scope, receive, send)
                    return
                await self.app(scope, receive, send)

        middleware = [Middleware(AuthMiddleware)]

    routes = []

    if TRANSPORT_SSE_ENABLED:
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                read_stream, write_stream = streams[0], streams[1]
                # Stash the HTTP headers in the session's experimental capabilities
                # so with_api_key can find the caller's key in host mode.
                init_options = app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)})
                await app.run(read_stream, write_stream, init_options)
            # Starlette route endpoints are expected to return a Response.
            return Response()

        routes += [
            Route("/sse", endpoint=handle_sse, methods=["GET"]),
            Mount("/messages/", app=sse.handle_post_message),
        ]

    lifespan = None
    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

        session_manager = StreamableHTTPSessionManager(
            app=app,
            event_store=None,
            json_response=JSON_RESPONSE,
            stateless=True,
        )

        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
            await session_manager.handle_request(scope, receive, send)

        @asynccontextmanager
        async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
                try:
                    yield
                finally:
                    logging.info("StreamableHTTP application shutting down...")

        routes.append(Mount("/mcp", app=handle_streamable_http))

    # NOTE(review): debug=True makes Starlette return tracebacks in error
    # responses; consider disabling it for internet-facing deployments.
    return Starlette(debug=True, routes=routes, middleware=middleware, lifespan=lifespan)
-
-
@click.command()
@click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
@click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
@click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
@click.option(
    "--mode",
    type=click.Choice(["self-host", "host"]),
    default="self-host",
    help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
)
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
@click.option(
    "--transport-sse-enabled/--no-transport-sse-enabled",
    default=True,
    help="Enable or disable legacy SSE transport mode (default: enabled)",
)
@click.option(
    "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
    default=True,
    help="Enable or disable streamable-http transport mode (default: enabled)",
)
@click.option(
    "--json-response/--no-json-response",
    default=True,
    help="Enable or disable JSON response mode for streamable-http (default: enabled)",
)
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
    """Run the RAGFlow MCP server.

    Every option can also be supplied through a RAGFLOW_MCP_* environment
    variable (a .env file is loaded via python-dotenv); environment values take
    precedence over command-line values.
    """
    import os

    import uvicorn
    from dotenv import load_dotenv

    load_dotenv()

    def parse_bool_flag(key: str, default: bool) -> bool:
        """Read env var ``key`` as a bool: 1/true/yes/on (case-insensitive) is True, anything else False."""
        val = os.environ.get(key, str(default))
        return str(val).strip().lower() in ("1", "true", "yes", "on")

    global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
    # Environment variables override the CLI values.
    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
    MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
    HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
    TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
    TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
    JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)

    # The env var bypasses click.Choice; an unknown mode would otherwise run as
    # neither self-host (no key check) nor host (no auth middleware).
    if MODE not in (LaunchMode.SELF_HOST, LaunchMode.HOST):
        raise click.UsageError(f"Invalid launch mode {MODE!r}: expected 'self-host' or 'host'")

    try:
        port_number = int(PORT)
    except ValueError:
        raise click.UsageError(f"Invalid port {PORT!r}: expected an integer") from None

    # Apply the "at least one transport" fallback BEFORE the checks below, so
    # they see the transport set that will actually run. (Previously the fallback
    # ran after JSON_RESPONSE had been forced off and after the host-mode check.)
    streamable_auto_enabled = not (TRANSPORT_SSE_ENABLED or TRANSPORT_STREAMABLE_HTTP_ENABLED)
    if streamable_auto_enabled:
        TRANSPORT_STREAMABLE_HTTP_ENABLED = True

    if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
        raise click.UsageError("--api-key is required when --mode is 'self-host'")

    if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
        raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")

    # JSON response mode only applies to streamable HTTP.
    if not TRANSPORT_STREAMABLE_HTTP_ENABLED:
        JSON_RESPONSE = False

    print(
        r"""
__ __ ____ ____ ____ _____ ______ _______ ____
| \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \
| |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |
| | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <
|_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\
""",
        flush=True,
    )
    print(f"MCP launch mode: {MODE}", flush=True)
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)

    if streamable_auto_enabled:
        print("At least one transport should be enabled, enable streamable-http automatically", flush=True)

    if TRANSPORT_SSE_ENABLED:
        print("SSE transport enabled: yes", flush=True)
        print("SSE endpoint available at /sse", flush=True)
    else:
        print("SSE transport enabled: no", flush=True)

    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        print("Streamable HTTP transport enabled: yes", flush=True)
        print("Streamable HTTP endpoint available at /mcp", flush=True)
        if JSON_RESPONSE:
            print("Streamable HTTP mode: JSON response enabled", flush=True)
        else:
            print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
    else:
        # JSON_RESPONSE was normalised to False above, so the old
        # "--json-response ignored" warning here could never fire; it was removed.
        print("Streamable HTTP transport enabled: no", flush=True)

    uvicorn.run(
        create_starlette_app(),
        host=HOST,
        port=port_number,
    )
-
-
if __name__ == "__main__":
    # Launch examples:
    #
    # 1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
    #    uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
    #        --base-url=http://127.0.0.1:9380 \
    #        --mode=self-host --api-key=ragflow-xxxxx
    #
    # 2. Host mode (multi-tenant; clients must provide Authorization headers).
    #    Host mode currently supports only the legacy SSE transport, so streamable
    #    HTTP must be disabled explicitly; otherwise startup fails with a usage error:
    #    uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
    #        --base-url=http://127.0.0.1:9380 \
    #        --mode=host --no-transport-streamable-http-enabled
    #
    # 3. Disable legacy SSE (only streamable HTTP will be active):
    #    uv run mcp/server/server.py --no-transport-sse-enabled \
    #        --mode=self-host --api-key=ragflow-xxxxx
    #
    # 4. Disable streamable HTTP (only legacy SSE will be active):
    #    uv run mcp/server/server.py --no-transport-streamable-http-enabled \
    #        --mode=self-host --api-key=ragflow-xxxxx
    #
    # 5. Use streamable HTTP with SSE-style events (disable JSON response):
    #    uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
    #        --mode=self-host --api-key=ragflow-xxxxx
    #
    # 6. Disable both transports (for testing). At least one transport is required,
    #    so streamable HTTP is re-enabled automatically at startup:
    #    uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
    #        --mode=self-host --api-key=ragflow-xxxxx
    main()
|