#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import time
from collections import OrderedDict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps

import click
import requests
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum

import mcp.types as types
from mcp.server.lowlevel import Server
class LaunchMode(StrEnum):
    SELF_HOST = "self-host"
    HOST = "host"


class Transport(StrEnum):
    SSE = "sse"
    STREAMABLE_HTTP = "streamable-http"
BASE_URL = "http://127.0.0.1:9380"
HOST = "127.0.0.1"
PORT = "9382"
HOST_API_KEY = ""
MODE = ""
TRANSPORT_SSE_ENABLED = True
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
JSON_RESPONSE = True
class RAGFlowConnector:
    _MAX_DATASET_CACHE = 32
    _MAX_DOCUMENT_CACHE = 128
    _CACHE_TTL = 300

    _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
    _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)

    def __init__(self, base_url: str, version="v1"):
        self.base_url = base_url
        self.version = version
        self.api_url = f"{self.base_url}/api/{self.version}"
        # Ensure the attributes used by _post/_get exist even before bind_api_key() is called.
        self.api_key = ""
        self.authorization_header = {}

    def bind_api_key(self, api_key: str):
        self.api_key = api_key
        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
    def _post(self, path, json=None, stream=False, files=None):
        if not self.api_key:
            return None
        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
        return res

    def _get(self, path, params=None, json=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
        return res

    def _is_cache_valid(self, ts):
        return time.time() < ts

    def _get_expiry_timestamp(self):
        offset = random.randint(-30, 30)
        return time.time() + self._CACHE_TTL + offset

    def _get_cached_dataset_metadata(self, dataset_id):
        entry = self._dataset_metadata_cache.get(dataset_id)
        if entry:
            data, ts = entry
            if self._is_cache_valid(ts):
                self._dataset_metadata_cache.move_to_end(dataset_id)
                return data
        return None

    def _set_cached_dataset_metadata(self, dataset_id, metadata):
        self._dataset_metadata_cache[dataset_id] = (metadata, self._get_expiry_timestamp())
        self._dataset_metadata_cache.move_to_end(dataset_id)
        if len(self._dataset_metadata_cache) > self._MAX_DATASET_CACHE:
            self._dataset_metadata_cache.popitem(last=False)

    def _get_cached_document_metadata_by_dataset(self, dataset_id):
        entry = self._document_metadata_cache.get(dataset_id)
        if entry:
            data_list, ts = entry
            if self._is_cache_valid(ts):
                self._document_metadata_cache.move_to_end(dataset_id)
                return {doc_id: doc_meta for doc_id, doc_meta in data_list}
        return None

    def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
        self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
        self._document_metadata_cache.move_to_end(dataset_id)
        if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
            self._document_metadata_cache.popitem(last=False)
    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
        res = res.json()
        if res.get("code") == 0:
            result_list = []
            for data in res["data"]:
                d = {"description": data["description"], "id": data["id"]}
                result_list.append(json.dumps(d, ensure_ascii=False))
            return "\n".join(result_list)
        return ""
    def retrieval(
        self,
        dataset_ids,
        document_ids=None,
        question="",
        page=1,
        page_size=30,
        similarity_threshold=0.2,
        vector_similarity_weight=0.3,
        top_k=1024,
        rerank_id: str | None = None,
        keyword: bool = False,
        force_refresh: bool = False,
    ):
        if document_ids is None:
            document_ids = []

        # If no dataset_ids provided or empty list, get all available dataset IDs
        if not dataset_ids:
            dataset_list_str = self.list_datasets()
            dataset_ids = []
            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split('\n'):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        data_json = {
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "rerank_id": rerank_id,
            "keyword": keyword,
            "question": question,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
        }
        # Send a POST request to the backend retrieval endpoint
        res = self._post("/retrieval", json=data_json)
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
        res = res.json()
        if res.get("code") == 0:
            data = res["data"]
            chunks = []
            # Cache document metadata and dataset information
            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
            # Process chunks with enhanced field mapping including per-chunk metadata
            for chunk_data in data.get("chunks", []):
                enhanced_chunk = self._map_chunk_fields(chunk_data, dataset_cache, document_cache)
                chunks.append(enhanced_chunk)
            # Build structured response (no longer need response-level document_metadata)
            response = {
                "chunks": chunks,
                "pagination": {
                    "page": data.get("page", page),
                    "page_size": data.get("page_size", page_size),
                    "total_chunks": data.get("total", len(chunks)),
                    "total_pages": (data.get("total", len(chunks)) + page_size - 1) // page_size,
                },
                "query_info": {
                    "question": question,
                    "similarity_threshold": similarity_threshold,
                    "vector_weight": vector_similarity_weight,
                    "keyword_search": keyword,
                    "dataset_count": len(dataset_ids),
                },
            }
            return [types.TextContent(type="text", text=json.dumps(response, ensure_ascii=False))]
        raise Exception([types.TextContent(type="text", text=res.get("message"))])
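    # Illustrative shape of the JSON payload returned by retrieval() above (values are
    # placeholders; chunk fields are passed through from the RAGFlow API as-is, and only
    # dataset_name, document_name, and document_metadata are added here):
    # {
    #     "chunks": [{"content": "...", "document_id": "...", "dataset_name": "...",
    #                 "document_name": "...", "document_metadata": {"name": "...", "size": 12345}}],
    #     "pagination": {"page": 1, "page_size": 10, "total_chunks": 42, "total_pages": 5},
    #     "query_info": {"question": "...", "similarity_threshold": 0.2, "vector_weight": 0.3,
    #                    "keyword_search": false, "dataset_count": 2}
    # }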
    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
        """Cache document metadata for all documents in the specified datasets"""
        document_cache = {}
        dataset_cache = {}
        try:
            for dataset_id in dataset_ids:
                dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
                if not dataset_meta:
                    # First get dataset info for name
                    dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
                    if dataset_res and dataset_res.status_code == 200:
                        dataset_data = dataset_res.json()
                        if dataset_data.get("code") == 0 and dataset_data.get("data"):
                            dataset_info = dataset_data["data"][0]
                            dataset_meta = {"name": dataset_info.get("name", "Unknown"), "description": dataset_info.get("description", "")}
                            self._set_cached_dataset_metadata(dataset_id, dataset_meta)
                if dataset_meta:
                    dataset_cache[dataset_id] = dataset_meta

                docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
                if docs is None:
                    docs_res = self._get(f"/datasets/{dataset_id}/documents")
                    docs_data = docs_res.json()
                    if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                        doc_id_meta_list = []
                        docs = {}
                        for doc in docs_data["data"]["docs"]:
                            doc_id = doc.get("id")
                            if not doc_id:
                                continue
                            doc_meta = {
                                "document_id": doc_id,
                                "name": doc.get("name", ""),
                                "location": doc.get("location", ""),
                                "type": doc.get("type", ""),
                                "size": doc.get("size"),
                                "chunk_count": doc.get("chunk_count"),
                                # "chunk_method": doc.get("chunk_method", ""),
                                "create_date": doc.get("create_date", ""),
                                "update_date": doc.get("update_date", ""),
                                # "process_begin_at": doc.get("process_begin_at", ""),
                                # "process_duration": doc.get("process_duration"),
                                # "progress": doc.get("progress"),
                                # "progress_msg": doc.get("progress_msg", ""),
                                # "status": doc.get("status", ""),
                                # "run": doc.get("run", ""),
                                "token_count": doc.get("token_count"),
                                # "source_type": doc.get("source_type", ""),
                                "thumbnail": doc.get("thumbnail", ""),
                                "dataset_id": doc.get("dataset_id", dataset_id),
                                "meta_fields": doc.get("meta_fields", {}),
                                # "parser_config": doc.get("parser_config", {})
                            }
                            doc_id_meta_list.append((doc_id, doc_meta))
                            docs[doc_id] = doc_meta
                        self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
                if docs:
                    document_cache.update(docs)
        except Exception:
            # Gracefully handle metadata cache failures
            pass
        return document_cache, dataset_cache
    def _map_chunk_fields(self, chunk_data, dataset_cache, document_cache):
        """Preserve all original API fields and add per-chunk document metadata"""
        # Start with ALL raw data from API (preserve everything like original version)
        mapped = dict(chunk_data)
        # Add dataset name enhancement
        dataset_id = chunk_data.get("dataset_id") or chunk_data.get("kb_id")
        if dataset_id and dataset_id in dataset_cache:
            mapped["dataset_name"] = dataset_cache[dataset_id]["name"]
        else:
            mapped["dataset_name"] = "Unknown"
        # Add document name convenience field
        mapped["document_name"] = chunk_data.get("document_keyword", "")
        # Add per-chunk document metadata
        document_id = chunk_data.get("document_id")
        if document_id and document_id in document_cache:
            mapped["document_metadata"] = document_cache[document_id]
        return mapped
class RAGFlowCtx:
    def __init__(self, connector: RAGFlowConnector):
        self.conn = connector


@asynccontextmanager
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
    ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
    logging.info("Legacy SSE application started!")
    try:
        yield {"ragflow_ctx": ctx}
    finally:
        logging.info("Legacy SSE application shutting down...")


app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
def with_api_key(required=True):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Get RAGFlow Context failed")
            connector = ragflow_ctx.conn

            if MODE == LaunchMode.HOST:
                headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
                token = None
                # lower case here, because of Starlette conversion
                auth = headers.get("authorization", "")
                if auth.startswith("Bearer "):
                    token = auth.removeprefix("Bearer ").strip()
                elif "api_key" in headers:
                    token = headers["api_key"]
                if required and not token:
                    raise ValueError("RAGFlow API key or Bearer token is required.")
                connector.bind_api_key(token)
            else:
                connector.bind_api_key(HOST_API_KEY)

            return await func(*args, connector=connector, **kwargs)

        return wrapper

    return decorator
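# In host mode, every client request must carry credentials; either header below is
# accepted by AuthMiddleware (defined in create_starlette_app) and by with_api_key
# (the key value is a placeholder):
#     Authorization: Bearer ragflow-xxxxx
#     api_key: ragflow-xxxxx
# In self-host mode the server binds the configured --api-key itself, so clients do not
# need to send authentication headers.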
@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
    dataset_description = connector.list_datasets()

    return [
        types.Tool(
            name="ragflow_retrieval",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:"
            + dataset_description,
            inputSchema={
                "type": "object",
                "properties": {
                    "dataset_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched.",
                    },
                    "document_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional array of document IDs to search within.",
                    },
                    "question": {
                        "type": "string",
                        "description": "The question or query to search for.",
                    },
                    "page": {
                        "type": "integer",
                        "description": "Page number for pagination",
                        "default": 1,
                        "minimum": 1,
                    },
                    "page_size": {
                        "type": "integer",
                        "description": "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)",
                        "default": 10,
                        "minimum": 1,
                        "maximum": 100,
                    },
                    "similarity_threshold": {
                        "type": "number",
                        "description": "Minimum similarity threshold for results",
                        "default": 0.2,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "vector_similarity_weight": {
                        "type": "number",
                        "description": "Weight for vector similarity vs term similarity",
                        "default": 0.3,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "keyword": {
                        "type": "boolean",
                        "description": "Enable keyword-based search",
                        "default": False,
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "Maximum results to consider before ranking",
                        "default": 1024,
                        "minimum": 1,
                        "maximum": 1024,
                    },
                    "rerank_id": {
                        "type": "string",
                        "description": "Optional reranking model identifier",
                    },
                    "force_refresh": {
                        "type": "boolean",
                        "description": "Set to true only if fresh dataset and document metadata is explicitly required. Otherwise, cached metadata is used (default: false).",
                        "default": False,
                    },
                },
                "required": ["question"],
            },
        ),
    ]
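# Example arguments for a ragflow_retrieval call, matching the schema above
# (IDs are placeholders; dataset_ids may be omitted to search every dataset):
#     {
#         "question": "How do I configure chunking?",
#         "dataset_ids": ["<dataset-id>"],
#         "page_size": 10,
#         "similarity_threshold": 0.2,
#         "keyword": False,
#     }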
@app.call_tool()
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    if name == "ragflow_retrieval":
        document_ids = arguments.get("document_ids", [])
        dataset_ids = arguments.get("dataset_ids", [])
        question = arguments.get("question", "")
        page = arguments.get("page", 1)
        page_size = arguments.get("page_size", 10)
        similarity_threshold = arguments.get("similarity_threshold", 0.2)
        vector_similarity_weight = arguments.get("vector_similarity_weight", 0.3)
        keyword = arguments.get("keyword", False)
        top_k = arguments.get("top_k", 1024)
        rerank_id = arguments.get("rerank_id")
        force_refresh = arguments.get("force_refresh", False)

        # If no dataset_ids provided or empty list, get all available dataset IDs
        if not dataset_ids:
            dataset_list_str = connector.list_datasets()
            dataset_ids = []
            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split('\n'):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        return connector.retrieval(
            dataset_ids=dataset_ids,
            document_ids=document_ids,
            question=question,
            page=page,
            page_size=page_size,
            similarity_threshold=similarity_threshold,
            vector_similarity_weight=vector_similarity_weight,
            keyword=keyword,
            top_k=top_k,
            rerank_id=rerank_id,
            force_refresh=force_refresh,
        )

    raise ValueError(f"Tool not found: {name}")
def create_starlette_app():
    routes = []
    middleware = None

    if MODE == LaunchMode.HOST:
        from starlette.types import ASGIApp, Receive, Scope, Send

        class AuthMiddleware:
            def __init__(self, app: ASGIApp):
                self.app = app

            async def __call__(self, scope: Scope, receive: Receive, send: Send):
                if scope["type"] != "http":
                    await self.app(scope, receive, send)
                    return
                path = scope["path"]
                if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
                    headers = dict(scope["headers"])
                    token = None
                    auth_header = headers.get(b"authorization")
                    if auth_header and auth_header.startswith(b"Bearer "):
                        token = auth_header.removeprefix(b"Bearer ").strip()
                    elif b"api_key" in headers:
                        token = headers[b"api_key"]
                    if not token:
                        response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                        await response(scope, receive, send)
                        return
                await self.app(scope, receive, send)

        middleware = [Middleware(AuthMiddleware)]

    # Add SSE routes if enabled
    if TRANSPORT_SSE_ENABLED:
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
            return Response()

        routes.extend(
            [
                Route("/sse", endpoint=handle_sse, methods=["GET"]),
                Mount("/messages/", app=sse.handle_post_message),
            ]
        )

    # Add streamable HTTP route if enabled
    streamablehttp_lifespan = None
    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        from starlette.types import Receive, Scope, Send

        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

        session_manager = StreamableHTTPSessionManager(
            app=app,
            event_store=None,
            json_response=JSON_RESPONSE,
            stateless=True,
        )

        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
            await session_manager.handle_request(scope, receive, send)

        @asynccontextmanager
        async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
                try:
                    yield
                finally:
                    logging.info("StreamableHTTP application shutting down...")

        routes.append(Mount("/mcp", app=handle_streamable_http))

    return Starlette(
        debug=True,
        routes=routes,
        middleware=middleware,
        lifespan=streamablehttp_lifespan,
    )
@click.command()
@click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
@click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
@click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
@click.option(
    "--mode",
    type=click.Choice(["self-host", "host"]),
    default="self-host",
    help=("Launch mode:\n  self-host: run MCP for a single tenant (requires --api-key)\n  host: multi-tenant mode, users must provide Authorization headers"),
)
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
@click.option(
    "--transport-sse-enabled/--no-transport-sse-enabled",
    default=True,
    help="Enable or disable legacy SSE transport mode (default: enabled)",
)
@click.option(
    "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
    default=True,
    help="Enable or disable streamable-http transport mode (default: enabled)",
)
@click.option(
    "--json-response/--no-json-response",
    default=True,
    help="Enable or disable JSON response mode for streamable-http (default: enabled)",
)
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
    import os

    import uvicorn
    from dotenv import load_dotenv

    load_dotenv()

    def parse_bool_flag(key: str, default: bool) -> bool:
        val = os.environ.get(key, str(default))
        return str(val).strip().lower() in ("1", "true", "yes", "on")

    global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE

    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
    MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
    HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
    TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
    TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
    JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)

    if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
        raise click.UsageError("--api-key is required when --mode is 'self-host'")
    if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
        raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
    if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
        JSON_RESPONSE = False
    print(
        r"""
 __  __  ____ ____    ____  _____ ______     _______ ____
|  \/  |/ ___|  _ \  / ___|| ____|  _ \ \   / / ____|  _ \
| |\/| | |   | |_) | \___ \|  _| | |_) \ \ / /|  _| | |_) |
| |  | | |___|  __/   ___) | |___|  _ < \ V / | |___|  _ <
|_|  |_|\____|_|     |____/|_____|_| \_\ \_/  |_____|_| \_\
""",
        flush=True,
    )
    print(f"MCP launch mode: {MODE}", flush=True)
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)
    if not any([TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED]):
        print("At least one transport should be enabled; enabling streamable-http automatically", flush=True)
        TRANSPORT_STREAMABLE_HTTP_ENABLED = True

    if TRANSPORT_SSE_ENABLED:
        print("SSE transport enabled: yes", flush=True)
        print("SSE endpoint available at /sse", flush=True)
    else:
        print("SSE transport enabled: no", flush=True)

    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        print("Streamable HTTP transport enabled: yes", flush=True)
        print("Streamable HTTP endpoint available at /mcp", flush=True)
        if JSON_RESPONSE:
            print("Streamable HTTP mode: JSON response enabled", flush=True)
        else:
            print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
    else:
        print("Streamable HTTP transport enabled: no", flush=True)
        if JSON_RESPONSE:
            print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)

    uvicorn.run(
        create_starlette_app(),
        host=HOST,
        port=int(PORT),
    )
if __name__ == "__main__":
    """
    Launch examples:

    1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
        uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
            --base-url=http://127.0.0.1:9380 \
            --mode=self-host --api-key=ragflow-xxxxx
    2. Host mode (multi-tenant, SSE transport only, clients must provide Authorization headers):
        uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
            --base-url=http://127.0.0.1:9380 \
            --mode=host --no-transport-streamable-http-enabled
    3. Disable legacy SSE (only streamable HTTP will be active):
        uv run mcp/server/server.py --no-transport-sse-enabled \
            --mode=self-host --api-key=ragflow-xxxxx
    4. Disable streamable HTTP (only legacy SSE will be active):
        uv run mcp/server/server.py --no-transport-streamable-http-enabled \
            --mode=self-host --api-key=ragflow-xxxxx
    5. Use streamable HTTP with SSE-style events (disable JSON response):
        uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
            --mode=self-host --api-key=ragflow-xxxxx
    6. Disable both transports (for testing; streamable HTTP is re-enabled automatically):
        uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
            --mode=self-host --api-key=ragflow-xxxxx
    """
    main()
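# A minimal client sketch for exercising this server over the legacy SSE transport,
# assuming the official MCP Python SDK (`mcp` package); the URL, API key, and question
# below are placeholders, and the headers argument can be omitted in self-host mode:
#
#     import asyncio
#     from mcp import ClientSession
#     from mcp.client.sse import sse_client
#
#     async def demo():
#         async with sse_client("http://127.0.0.1:9382/sse", headers={"api_key": "ragflow-xxxxx"}) as (read, write):
#             async with ClientSession(read, write) as session:
#                 await session.initialize()
#                 tools = await session.list_tools()
#                 result = await session.call_tool("ragflow_retrieval", {"question": "What is RAGFlow?"})
#                 print(tools, result)
#
#     asyncio.run(demo())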