You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

server.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import logging
  18. from collections.abc import AsyncIterator
  19. from contextlib import asynccontextmanager
  20. from functools import wraps
  21. import click
  22. import requests
  23. from starlette.applications import Starlette
  24. from starlette.middleware import Middleware
  25. from starlette.responses import JSONResponse, Response
  26. from starlette.routing import Mount, Route
  27. from strenum import StrEnum
  28. import mcp.types as types
  29. from mcp.server.lowlevel import Server
  30. class LaunchMode(StrEnum):
  31. SELF_HOST = "self-host"
  32. HOST = "host"
  33. class Transport(StrEnum):
  34. SSE = "sse"
  35. STEAMABLE_HTTP = "streamable-http"
  # Module-level configuration defaults. main() rebinds all of these through
  # ``global`` from the CLI options and the RAGFLOW_MCP_* environment
  # variables before the server is started.

  # Address of the RAGFlow backend whose REST API is called.
  BASE_URL: str = "http://127.0.0.1:9380"

  # Bind address of this MCP server. PORT is kept as a string and converted
  # with int() only when the server is launched.
  HOST: str = "127.0.0.1"
  PORT: str = "9382"

  # API key bound to the connector when not running in host mode.
  HOST_API_KEY: str = ""
  # Launch mode string; compared against the LaunchMode members.
  MODE: str = ""

  # Transport switches and the streamable-http response style.
  TRANSPORT_SSE_ENABLED: bool = True
  TRANSPORT_STREAMABLE_HTTP_ENABLED: bool = True
  JSON_RESPONSE: bool = True
  44. class RAGFlowConnector:
  45. def __init__(self, base_url: str, version="v1"):
  46. self.base_url = base_url
  47. self.version = version
  48. self.api_url = f"{self.base_url}/api/{self.version}"
  49. def bind_api_key(self, api_key: str):
  50. self.api_key = api_key
  51. self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
  52. def _post(self, path, json=None, stream=False, files=None):
  53. if not self.api_key:
  54. return None
  55. res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
  56. return res
  57. def _get(self, path, params=None, json=None):
  58. res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
  59. return res
  60. def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
  61. res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
  62. if not res:
  63. raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
  64. res = res.json()
  65. if res.get("code") == 0:
  66. result_list = []
  67. for data in res["data"]:
  68. d = {"description": data["description"], "id": data["id"]}
  69. result_list.append(json.dumps(d, ensure_ascii=False))
  70. return "\n".join(result_list)
  71. return ""
  72. def retrieval(
  73. self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
  74. ):
  75. if document_ids is None:
  76. document_ids = []
  77. data_json = {
  78. "page": page,
  79. "page_size": page_size,
  80. "similarity_threshold": similarity_threshold,
  81. "vector_similarity_weight": vector_similarity_weight,
  82. "top_k": top_k,
  83. "rerank_id": rerank_id,
  84. "keyword": keyword,
  85. "question": question,
  86. "dataset_ids": dataset_ids,
  87. "document_ids": document_ids,
  88. }
  89. # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
  90. res = self._post("/retrieval", json=data_json)
  91. if not res:
  92. raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
  93. res = res.json()
  94. if res.get("code") == 0:
  95. chunks = []
  96. for chunk_data in res["data"].get("chunks"):
  97. chunks.append(json.dumps(chunk_data, ensure_ascii=False))
  98. return [types.TextContent(type="text", text="\n".join(chunks))]
  99. raise Exception([types.TextContent(type="text", text=res.get("message"))])
  class RAGFlowCtx:
      """Lifespan-scoped holder for the RAGFlow connector."""

      def __init__(self, connector: RAGFlowConnector):
          # Request handlers reach the connector through ``.conn``.
          self.conn = connector
  @asynccontextmanager
  async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
      """Provide the lifespan state handed to the MCP ``Server``.

      A connector pointing at ``BASE_URL`` is created and wrapped in a
      ``RAGFlowCtx``; the yielded dict exposes it under ``"ragflow_ctx"``.
      Start-up and shutdown are logged around the yield.
      """
      connector = RAGFlowConnector(base_url=BASE_URL)
      lifespan_state = {"ragflow_ctx": RAGFlowCtx(connector)}
      logging.info("Legacy SSE application started with StreamableHTTP session manager!")
      try:
          yield lifespan_state
      finally:
          logging.info("Legacy SSE application shutting down...")


  # The MCP server instance; handlers below are registered on it.
  app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
  def with_api_key(required=True):
      """Decorator factory that injects an authenticated connector.

      The wrapped coroutine is awaited with an extra ``connector`` keyword
      argument taken from the ``"ragflow_ctx"`` entry of the request's
      lifespan context.

      * Outside host mode the connector is bound to ``HOST_API_KEY``.
      * In host mode the key comes from the headers stored in the session's
        initialization options: an ``authorization: Bearer <token>`` header,
        or otherwise an ``api_key`` header.

      Args:
          required: in host mode, raise ``ValueError`` when no token is found.

      Raises:
          ValueError: when the lifespan context has no RAGFlow context, or a
              required token is missing.
      """

      def decorator(func):
          @wraps(func)
          async def wrapper(*args, **kwargs):
              request_ctx = app.request_context
              ragflow_ctx = request_ctx.lifespan_context.get("ragflow_ctx")
              if not ragflow_ctx:
                  raise ValueError("Get RAGFlow Context failed")
              connector = ragflow_ctx.conn

              if MODE != LaunchMode.HOST:
                  connector.bind_api_key(HOST_API_KEY)
                  return await func(*args, connector=connector, **kwargs)

              headers = request_ctx.session._init_options.capabilities.experimental.get("headers", {})
              # Header names are lower case because of Starlette's conversion.
              auth = headers.get("authorization", "")
              if auth.startswith("Bearer "):
                  token = auth.removeprefix("Bearer ").strip()
              elif "api_key" in headers:
                  token = headers["api_key"]
              else:
                  token = None

              if required and not token:
                  raise ValueError("RAGFlow API key or Bearer token is required.")
              connector.bind_api_key(token)
              return await func(*args, connector=connector, **kwargs)

          return wrapper

      return decorator
  @app.list_tools()
  @with_api_key(required=True)
  async def list_tools(*, connector) -> list[types.Tool]:
      """Advertise the single ``ragflow_retrieval`` tool.

      The tool description ends with the dataset listing returned by
      ``connector.list_datasets()``, appended directly after the fixed text.
      """
      dataset_description = connector.list_datasets()

      string_array = {"type": "array", "items": {"type": "string"}}
      input_schema = {
          "type": "object",
          "properties": {
              "dataset_ids": dict(string_array),
              "document_ids": dict(string_array),
              "question": {"type": "string"},
          },
          "required": ["dataset_ids", "question"],
      }
      description = (
          "Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function."
          + dataset_description
      )
      retrieval_tool = types.Tool(
          name="ragflow_retrieval",
          description=description,
          inputSchema=input_schema,
      )
      return [retrieval_tool]
  @app.call_tool()
  @with_api_key(required=True)
  async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
      """Dispatch a tool call; only ``ragflow_retrieval`` is supported.

      ``dataset_ids`` and ``question`` are read from ``arguments`` by key;
      ``document_ids`` is optional and defaults to an empty list.

      Raises:
          ValueError: when ``name`` is not a known tool.
      """
      if name != "ragflow_retrieval":
          raise ValueError(f"Tool not found: {name}")

      return connector.retrieval(
          dataset_ids=arguments["dataset_ids"],
          document_ids=arguments.get("document_ids", []),
          question=arguments["question"],
      )
  def create_starlette_app():
      """Assemble the Starlette application for the enabled transports.

      * In host mode an ASGI middleware rejects HTTP requests to the
        ``/messages/``, ``/sse`` and ``/mcp`` paths with a 401 unless they
        carry a ``Bearer`` authorization header or an ``api_key`` header.
        It only checks that a token is present.
      * With SSE enabled, ``/sse`` (GET) and the ``/messages/`` mount are added.
      * With streamable HTTP enabled, ``/mcp`` is mounted and the session
        manager is run as the application lifespan.
      """
      routes = []
      middleware = None
      lifespan = None

      if MODE == LaunchMode.HOST:
          from starlette.types import ASGIApp, Receive, Scope, Send

          protected_prefixes = ("/messages/", "/sse", "/mcp")

          class AuthMiddleware:
              def __init__(self, app: ASGIApp):
                  self.app = app

              async def __call__(self, scope: Scope, receive: Receive, send: Send):
                  if scope["type"] == "http" and scope["path"].startswith(protected_prefixes):
                      headers = dict(scope["headers"])
                      auth_header = headers.get(b"authorization")
                      if auth_header and auth_header.startswith(b"Bearer "):
                          token = auth_header.removeprefix(b"Bearer ").strip()
                      elif b"api_key" in headers:
                          token = headers[b"api_key"]
                      else:
                          token = None

                      if not token:
                          response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                          await response(scope, receive, send)
                          return

                  await self.app(scope, receive, send)

          middleware = [Middleware(AuthMiddleware)]

      # Legacy SSE transport.
      if TRANSPORT_SSE_ENABLED:
          from mcp.server.sse import SseServerTransport

          sse = SseServerTransport("/messages/")

          async def handle_sse(request):
              async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                  read_stream, write_stream = streams[0], streams[1]
                  # The request headers travel with the initialization options
                  # so that with_api_key can read them in host mode.
                  init_options = app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)})
                  await app.run(read_stream, write_stream, init_options)
              return Response()

          routes.append(Route("/sse", endpoint=handle_sse, methods=["GET"]))
          routes.append(Mount("/messages/", app=sse.handle_post_message))

      # Streamable HTTP transport.
      if TRANSPORT_STREAMABLE_HTTP_ENABLED:
          from starlette.types import Receive, Scope, Send

          from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

          session_manager = StreamableHTTPSessionManager(
              app=app,
              event_store=None,
              json_response=JSON_RESPONSE,
              stateless=True,
          )

          async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
              await session_manager.handle_request(scope, receive, send)

          @asynccontextmanager
          async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
              async with session_manager.run():
                  logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
                  try:
                      yield
                  finally:
                      logging.info("StreamableHTTP application shutting down...")

          lifespan = streamablehttp_lifespan
          routes.append(Mount("/mcp", app=handle_streamable_http))

      # NOTE(review): debug=True is kept from the original; confirm it is
      # intended outside development.
      return Starlette(
          debug=True,
          routes=routes,
          middleware=middleware,
          lifespan=lifespan,
      )
  @click.command()
  @click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
  @click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
  @click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
  @click.option(
      "--mode",
      type=click.Choice(["self-host", "host"]),
      default="self-host",
      help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
  )
  @click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
  @click.option(
      "--transport-sse-enabled/--no-transport-sse-enabled",
      default=True,
      help="Enable or disable legacy SSE transport mode (default: enabled)",
  )
  @click.option(
      "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
      default=True,
      help="Enable or disable streamable-http transport mode (default: enabled)",
  )
  @click.option(
      "--json-response/--no-json-response",
      default=True,
      help="Enable or disable JSON response mode for streamable-http (default: enabled)",
  )
  def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
      """Launch the RAGFlow MCP server.

      Each option has a matching RAGFLOW_MCP_* environment variable; when that
      variable is set, its value takes precedence over the command-line value.
      """
      import os

      import uvicorn
      from dotenv import load_dotenv

      load_dotenv()

      def parse_bool_flag(key: str, default: bool) -> bool:
          # Only these spellings count as true; anything else is false.
          val = os.environ.get(key, str(default))
          return str(val).strip().lower() in ("1", "true", "yes", "on")

      global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
      BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
      HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
      PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
      MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
      HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
      TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
      TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
      JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)

      # Values taken from the environment bypass click's own validation, so
      # the launch mode and the port are checked again here.
      if MODE not in [launch_mode.value for launch_mode in LaunchMode]:
          raise click.UsageError(f"Invalid launch mode {MODE!r}: expected 'self-host' or 'host'")
      try:
          port_number = int(PORT)
      except ValueError as err:
          raise click.UsageError(f"Invalid port {PORT!r}: expected an integer") from err

      if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
          raise click.UsageError("--api-key is required when --mode is 'self-host'")

      # Resolve the "no transport enabled" fallback before the host-mode guard
      # and before JSON_RESPONSE is adjusted, so both see the final transports.
      streamable_auto_enabled = False
      if not (TRANSPORT_SSE_ENABLED or TRANSPORT_STREAMABLE_HTTP_ENABLED):
          if MODE == LaunchMode.HOST:
              raise click.UsageError("No transport enabled: host mode requires the SSE transport because streamable-http is not supported in host mode yet.")
          streamable_auto_enabled = True
          TRANSPORT_STREAMABLE_HTTP_ENABLED = True

      if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
          raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")

      # JSON response mode only applies to the streamable-http transport.
      if not TRANSPORT_STREAMABLE_HTTP_ENABLED:
          JSON_RESPONSE = False

      # Built from single-line raw strings so the banner text does not depend
      # on the indentation of this function.
      banner = "\n".join(
          [
              "",
              r"__ __ ____ ____ ____ _____ ______ _______ ____",
              r"| \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \ ".rstrip(),
              r"| |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |",
              r"| | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <",
              r"|_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\ ".rstrip(),
              "",
          ]
      )
      print(banner, flush=True)
      print(f"MCP launch mode: {MODE}", flush=True)
      print(f"MCP host: {HOST}", flush=True)
      print(f"MCP port: {PORT}", flush=True)
      print(f"MCP base_url: {BASE_URL}", flush=True)

      if streamable_auto_enabled:
          print("At least one transport should be enabled, enable streamable-http automatically", flush=True)

      if TRANSPORT_SSE_ENABLED:
          print("SSE transport enabled: yes", flush=True)
          print("SSE endpoint available at /sse", flush=True)
      else:
          print("SSE transport enabled: no", flush=True)

      if TRANSPORT_STREAMABLE_HTTP_ENABLED:
          print("Streamable HTTP transport enabled: yes", flush=True)
          print("Streamable HTTP endpoint available at /mcp", flush=True)
          if JSON_RESPONSE:
              print("Streamable HTTP mode: JSON response enabled", flush=True)
          else:
              print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
      else:
          print("Streamable HTTP transport enabled: no", flush=True)

      uvicorn.run(
          create_starlette_app(),
          host=HOST,
          port=port_number,
      )
  if __name__ == "__main__":
      # Launch examples:
      #
      # 1. Self-host mode with both SSE and Streamable HTTP (in JSON response
      #    mode) enabled (default):
      #      uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
      #        --base-url=http://127.0.0.1:9380 \
      #        --mode=self-host --api-key=ragflow-xxxxx
      #
      # 2. Host mode (multi-tenant; clients must provide Authorization headers).
      #    main() rejects host mode while streamable HTTP is enabled, so that
      #    transport has to be switched off explicitly:
      #      uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
      #        --base-url=http://127.0.0.1:9380 \
      #        --mode=host --no-transport-streamable-http-enabled
      #
      # 3. Disable legacy SSE (only streamable HTTP will be active):
      #      uv run mcp/server/server.py --no-transport-sse-enabled \
      #        --mode=self-host --api-key=ragflow-xxxxx
      #
      # 4. Disable streamable HTTP (only legacy SSE will be active):
      #      uv run mcp/server/server.py --no-transport-streamable-http-enabled \
      #        --mode=self-host --api-key=ragflow-xxxxx
      #
      # 5. Use streamable HTTP with SSE-style events (disable JSON response):
      #      uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
      #        --mode=self-host --api-key=ragflow-xxxxx
      #
      # 6. Disable both transports (for testing; main() then enables
      #    streamable HTTP automatically):
      #      uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
      #        --mode=self-host --api-key=ragflow-xxxxx
      main()