- #
 - #  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 - #
 - #  Licensed under the Apache License, Version 2.0 (the "License");
 - #  you may not use this file except in compliance with the License.
 - #  You may obtain a copy of the License at
 - #
 - #      http://www.apache.org/licenses/LICENSE-2.0
 - #
 - #  Unless required by applicable law or agreed to in writing, software
 - #  distributed under the License is distributed on an "AS IS" BASIS,
 - #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 - #  See the License for the specific language governing permissions and
 - #  limitations under the License.
 - #
 - 
 - import json
 - import logging
 - from collections.abc import AsyncIterator
 - from contextlib import asynccontextmanager
 - from functools import wraps
 - 
 - import click
 - import requests
 - from starlette.applications import Starlette
 - from starlette.middleware import Middleware
 - from starlette.responses import JSONResponse, Response
 - from starlette.routing import Mount, Route
 - from strenum import StrEnum
 - 
 - import mcp.types as types
 - from mcp.server.lowlevel import Server
 - 
 - 
 - class LaunchMode(StrEnum):
 -     SELF_HOST = "self-host"
 -     HOST = "host"
 - 
 - 
 - class Transport(StrEnum):
 -     SSE = "sse"
 -     STEAMABLE_HTTP = "streamable-http"
 - 
 - 
 - BASE_URL = "http://127.0.0.1:9380"
 - HOST = "127.0.0.1"
 - PORT = "9382"
 - HOST_API_KEY = ""
 - MODE = ""
 - TRANSPORT_SSE_ENABLED = True
 - TRANSPORT_STREAMABLE_HTTP_ENABLED = True
 - JSON_RESPONSE = True
 - 
 - 
 - class RAGFlowConnector:
 -     def __init__(self, base_url: str, version="v1"):
 -         self.base_url = base_url
 -         self.version = version
 -         self.api_url = f"{self.base_url}/api/{self.version}"
 - 
 -     def bind_api_key(self, api_key: str):
 -         self.api_key = api_key
 -         self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
 - 
 -     def _post(self, path, json=None, stream=False, files=None):
 -         if not self.api_key:
 -             return None
 -         res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
 -         return res
 - 
 -     def _get(self, path, params=None, json=None):
 -         res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
 -         return res
 - 
 -     def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
 -         res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
 -         if not res:
 -             raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
 - 
 -         res = res.json()
 -         if res.get("code") == 0:
 -             result_list = []
 -             for data in res["data"]:
 -                 d = {"description": data["description"], "id": data["id"]}
 -                 result_list.append(json.dumps(d, ensure_ascii=False))
 -             return "\n".join(result_list)
 -         return ""
 - 
 -     def retrieval(
 -         self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
 -     ):
 -         if document_ids is None:
 -             document_ids = []
 -         data_json = {
 -             "page": page,
 -             "page_size": page_size,
 -             "similarity_threshold": similarity_threshold,
 -             "vector_similarity_weight": vector_similarity_weight,
 -             "top_k": top_k,
 -             "rerank_id": rerank_id,
 -             "keyword": keyword,
 -             "question": question,
 -             "dataset_ids": dataset_ids,
 -             "document_ids": document_ids,
 -         }
 -         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
 -         res = self._post("/retrieval", json=data_json)
 -         if not res:
 -             raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
 - 
 -         res = res.json()
 -         if res.get("code") == 0:
 -             chunks = []
 -             for chunk_data in res["data"].get("chunks"):
 -                 chunks.append(json.dumps(chunk_data, ensure_ascii=False))
 -             return [types.TextContent(type="text", text="\n".join(chunks))]
 -         raise Exception([types.TextContent(type="text", text=res.get("message"))])
 - 
 - 
 - class RAGFlowCtx:
 -     def __init__(self, connector: RAGFlowConnector):
 -         self.conn = connector
 - 
 - 
 - @asynccontextmanager
 - async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
 -     ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
 - 
 -     logging.info("Legacy SSE application started with StreamableHTTP session manager!")
 -     try:
 -         yield {"ragflow_ctx": ctx}
 -     finally:
 -         logging.info("Legacy SSE application shutting down...")
 - 
 - 
 - app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
 - 
 - 
 - def with_api_key(required=True):
 -     def decorator(func):
 -         @wraps(func)
 -         async def wrapper(*args, **kwargs):
 -             ctx = app.request_context
 -             ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
 -             if not ragflow_ctx:
 -                 raise ValueError("Get RAGFlow Context failed")
 - 
 -             connector = ragflow_ctx.conn
 - 
 -             if MODE == LaunchMode.HOST:
 -                 headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
 -                 token = None
 - 
 -                 # lower case here, because of Starlette conversion
 -                 auth = headers.get("authorization", "")
 -                 if auth.startswith("Bearer "):
 -                     token = auth.removeprefix("Bearer ").strip()
 -                 elif "api_key" in headers:
 -                     token = headers["api_key"]
 - 
 -                 if required and not token:
 -                     raise ValueError("RAGFlow API key or Bearer token is required.")
 - 
 -                 connector.bind_api_key(token)
 -             else:
 -                 connector.bind_api_key(HOST_API_KEY)
 - 
 -             return await func(*args, connector=connector, **kwargs)
 - 
 -         return wrapper
 - 
 -     return decorator
 - 
 - 
 - @app.list_tools()
 - @with_api_key(required=True)
 - async def list_tools(*, connector) -> list[types.Tool]:
 -     dataset_description = connector.list_datasets()
 - 
 -     return [
 -         types.Tool(
 -             name="ragflow_retrieval",
 -             description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:"
 -             + dataset_description,
 -             inputSchema={
 -                 "type": "object",
 -                 "properties": {
 -                     "dataset_ids": {
 -                         "type": "array",
 -                         "items": {"type": "string"},
 -                         "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."
 -                     },
 -                     "document_ids": {
 -                         "type": "array",
 -                         "items": {"type": "string"},
 -                         "description": "Optional array of document IDs to search within."
 -                     },
 -                     "question": {"type": "string", "description": "The question or query to search for."},
 -                 },
 -                 "required": ["question"],
 -             },
 -         ),
 -     ]
 - 
 - 
 - @app.call_tool()
 - @with_api_key(required=True)
 - async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
 -     if name == "ragflow_retrieval":
 -         document_ids = arguments.get("document_ids", [])
 -         dataset_ids = arguments.get("dataset_ids", [])
 -         
 -         # If no dataset_ids provided or empty list, get all available dataset IDs
 -         if not dataset_ids:
 -             dataset_list_str = connector.list_datasets()
 -             dataset_ids = []
 -             
 -             # Parse the dataset list to extract IDs
 -             if dataset_list_str:
 -                 for line in dataset_list_str.strip().split('\n'):
 -                     if line.strip():
 -                         try:
 -                             dataset_info = json.loads(line.strip())
 -                             dataset_ids.append(dataset_info["id"])
 -                         except (json.JSONDecodeError, KeyError):
 -                             # Skip malformed lines
 -                             continue
 -         
 -         return connector.retrieval(
 -             dataset_ids=dataset_ids,
 -             document_ids=document_ids,
 -             question=arguments["question"],
 -         )
 -     raise ValueError(f"Tool not found: {name}")
 - 
 - 
 - def create_starlette_app():
 -     routes = []
 -     middleware = None
 -     if MODE == LaunchMode.HOST:
 -         from starlette.types import ASGIApp, Receive, Scope, Send
 - 
 -         class AuthMiddleware:
 -             def __init__(self, app: ASGIApp):
 -                 self.app = app
 - 
 -             async def __call__(self, scope: Scope, receive: Receive, send: Send):
 -                 if scope["type"] != "http":
 -                     await self.app(scope, receive, send)
 -                     return
 - 
 -                 path = scope["path"]
 -                 if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
 -                     headers = dict(scope["headers"])
 -                     token = None
 -                     auth_header = headers.get(b"authorization")
 -                     if auth_header and auth_header.startswith(b"Bearer "):
 -                         token = auth_header.removeprefix(b"Bearer ").strip()
 -                     elif b"api_key" in headers:
 -                         token = headers[b"api_key"]
 - 
 -                     if not token:
 -                         response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
 -                         await response(scope, receive, send)
 -                         return
 - 
 -                 await self.app(scope, receive, send)
 - 
 -         middleware = [Middleware(AuthMiddleware)]
 - 
 -     # Add SSE routes if enabled
 -     if TRANSPORT_SSE_ENABLED:
 -         from mcp.server.sse import SseServerTransport
 - 
 -         sse = SseServerTransport("/messages/")
 - 
 -         async def handle_sse(request):
 -             async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
 -                 await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
 -             return Response()
 - 
 -         routes.extend(
 -             [
 -                 Route("/sse", endpoint=handle_sse, methods=["GET"]),
 -                 Mount("/messages/", app=sse.handle_post_message),
 -             ]
 -         )
 - 
 -     # Add streamable HTTP route if enabled
 -     streamablehttp_lifespan = None
 -     if TRANSPORT_STREAMABLE_HTTP_ENABLED:
 -         from starlette.types import Receive, Scope, Send
 - 
 -         from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
 - 
 -         session_manager = StreamableHTTPSessionManager(
 -             app=app,
 -             event_store=None,
 -             json_response=JSON_RESPONSE,
 -             stateless=True,
 -         )
 - 
 -         async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
 -             await session_manager.handle_request(scope, receive, send)
 - 
 -         @asynccontextmanager
 -         async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
 -             async with session_manager.run():
 -                 logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
 -                 try:
 -                     yield
 -                 finally:
 -                     logging.info("StreamableHTTP application shutting down...")
 - 
 -         routes.append(Mount("/mcp", app=handle_streamable_http))
 - 
 -     return Starlette(
 -         debug=True,
 -         routes=routes,
 -         middleware=middleware,
 -         lifespan=streamablehttp_lifespan,
 -     )
 - 
 - 
 - @click.command()
 - @click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
 - @click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
 - @click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
 - @click.option(
 -     "--mode",
 -     type=click.Choice(["self-host", "host"]),
 -     default="self-host",
 -     help=("Launch mode:\n  self-host: run MCP for a single tenant (requires --api-key)\n  host: multi-tenant mode, users must provide Authorization headers"),
 - )
 - @click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
 - @click.option(
 -     "--transport-sse-enabled/--no-transport-sse-enabled",
 -     default=True,
 -     help="Enable or disable legacy SSE transport mode (default: enabled)",
 - )
 - @click.option(
 -     "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
 -     default=True,
 -     help="Enable or disable streamable-http transport mode (default: enabled)",
 - )
 - @click.option(
 -     "--json-response/--no-json-response",
 -     default=True,
 -     help="Enable or disable JSON response mode for streamable-http (default: enabled)",
 - )
 - def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
 -     import os
 - 
 -     import uvicorn
 -     from dotenv import load_dotenv
 - 
 -     load_dotenv()
 - 
 -     def parse_bool_flag(key: str, default: bool) -> bool:
 -         val = os.environ.get(key, str(default))
 -         return str(val).strip().lower() in ("1", "true", "yes", "on")
 - 
 -     global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
 -     BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
 -     HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
 -     PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
 -     MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
 -     HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
 -     TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
 -     TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
 -     JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)
 - 
 -     if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
 -         raise click.UsageError("--api-key is required when --mode is 'self-host'")
 - 
 -     if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
 -         raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
 - 
 -     if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
 -         JSON_RESPONSE = False
 - 
 -     print(
 -         r"""
 - __  __  ____ ____       ____  _____ ______     _______ ____
 - |  \/  |/ ___|  _ \     / ___|| ____|  _ \ \   / / ____|  _ \
 - | |\/| | |   | |_) |    \___ \|  _| | |_) \ \ / /|  _| | |_) |
 - | |  | | |___|  __/      ___) | |___|  _ < \ V / | |___|  _ <
 - |_|  |_|\____|_|        |____/|_____|_| \_\ \_/  |_____|_| \_\
 -         """,
 -         flush=True,
 -     )
 -     print(f"MCP launch mode: {MODE}", flush=True)
 -     print(f"MCP host: {HOST}", flush=True)
 -     print(f"MCP port: {PORT}", flush=True)
 -     print(f"MCP base_url: {BASE_URL}", flush=True)
 - 
 -     if not any([TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED]):
 -         print("At least one transport should be enabled, enable streamable-http automatically", flush=True)
 -         TRANSPORT_STREAMABLE_HTTP_ENABLED = True
 - 
 -     if TRANSPORT_SSE_ENABLED:
 -         print("SSE transport enabled: yes", flush=True)
 -         print("SSE endpoint available at /sse", flush=True)
 -     else:
 -         print("SSE transport enabled: no", flush=True)
 - 
 -     if TRANSPORT_STREAMABLE_HTTP_ENABLED:
 -         print("Streamable HTTP transport enabled: yes", flush=True)
 -         print("Streamable HTTP endpoint available at /mcp", flush=True)
 -         if JSON_RESPONSE:
 -             print("Streamable HTTP mode: JSON response enabled", flush=True)
 -         else:
 -             print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
 -     else:
 -         print("Streamable HTTP transport enabled: no", flush=True)
 -         if JSON_RESPONSE:
 -             print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)
 - 
 -     uvicorn.run(
 -         create_starlette_app(),
 -         host=HOST,
 -         port=int(PORT),
 -     )
 - 
 - 
 - if __name__ == "__main__":
 -     """
 -     Launch examples:
 - 
 -     1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
 -         uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
 -             --base-url=http://127.0.0.1:9380 \
 -             --mode=self-host --api-key=ragflow-xxxxx
 - 
 -     2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers):
 -         uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
 -             --base-url=http://127.0.0.1:9380 \
 -             --mode=host
 - 
 -     3. Disable legacy SSE (only streamable HTTP will be active):
 -         uv run mcp/server/server.py --no-transport-sse-enabled \
 -             --mode=self-host --api-key=ragflow-xxxxx
 - 
 -     4. Disable streamable HTTP (only legacy SSE will be active):
 -         uv run mcp/server/server.py --no-transport-streamable-http-enabled \
 -             --mode=self-host --api-key=ragflow-xxxxx
 - 
 -     5. Use streamable HTTP with SSE-style events (disable JSON response):
 -         uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
 -             --mode=self-host --api-key=ragflow-xxxxx
 - 
 -     6. Disable both transports (for testing):
 -         uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
 -             --mode=self-host --api-key=ragflow-xxxxx
 -     """
 -     main()
 
 
  |