You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

server.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from collections.abc import AsyncIterator
  18. from contextlib import asynccontextmanager
  19. from functools import wraps
  20. import click
  21. import requests
  22. from starlette.applications import Starlette
  23. from starlette.middleware import Middleware
  24. from starlette.responses import JSONResponse, Response
  25. from starlette.routing import Mount, Route
  26. from strenum import StrEnum
  27. import mcp.types as types
  28. from mcp.server.lowlevel import Server
  29. from mcp.server.sse import SseServerTransport
  30. class LaunchMode(StrEnum):
  31. SELF_HOST = "self-host"
  32. HOST = "host"
  33. BASE_URL = "http://127.0.0.1:9380"
  34. HOST = "127.0.0.1"
  35. PORT = "9382"
  36. HOST_API_KEY = ""
  37. MODE = ""
  38. class RAGFlowConnector:
  39. def __init__(self, base_url: str, version="v1"):
  40. self.base_url = base_url
  41. self.version = version
  42. self.api_url = f"{self.base_url}/api/{self.version}"
  43. def bind_api_key(self, api_key: str):
  44. self.api_key = api_key
  45. self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
  46. def _post(self, path, json=None, stream=False, files=None):
  47. if not self.api_key:
  48. return None
  49. res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
  50. return res
  51. def _get(self, path, params=None, json=None):
  52. res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
  53. return res
  54. def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
  55. res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
  56. if not res:
  57. raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
  58. res = res.json()
  59. if res.get("code") == 0:
  60. result_list = []
  61. for data in res["data"]:
  62. d = {"description": data["description"], "id": data["id"]}
  63. result_list.append(json.dumps(d, ensure_ascii=False))
  64. return "\n".join(result_list)
  65. return ""
  66. def retrieval(
  67. self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
  68. ):
  69. if document_ids is None:
  70. document_ids = []
  71. data_json = {
  72. "page": page,
  73. "page_size": page_size,
  74. "similarity_threshold": similarity_threshold,
  75. "vector_similarity_weight": vector_similarity_weight,
  76. "top_k": top_k,
  77. "rerank_id": rerank_id,
  78. "keyword": keyword,
  79. "question": question,
  80. "dataset_ids": dataset_ids,
  81. "document_ids": document_ids,
  82. }
  83. # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
  84. res = self._post("/retrieval", json=data_json)
  85. if not res:
  86. raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
  87. res = res.json()
  88. if res.get("code") == 0:
  89. chunks = []
  90. for chunk_data in res["data"].get("chunks"):
  91. chunks.append(json.dumps(chunk_data, ensure_ascii=False))
  92. return [types.TextContent(type="text", text="\n".join(chunks))]
  93. raise Exception([types.TextContent(type="text", text=res.get("message"))])
  94. class RAGFlowCtx:
  95. def __init__(self, connector: RAGFlowConnector):
  96. self.conn = connector
  97. @asynccontextmanager
  98. async def server_lifespan(server: Server) -> AsyncIterator[dict]:
  99. ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
  100. try:
  101. yield {"ragflow_ctx": ctx}
  102. finally:
  103. pass
# Low-level MCP server instance; server_lifespan supplies its lifespan context.
app = Server("ragflow-server", lifespan=server_lifespan)
# SSE transport; "/messages/" is the path mounted for client POST messages.
sse = SseServerTransport("/messages/")
  106. def with_api_key(required=True):
  107. def decorator(func):
  108. @wraps(func)
  109. async def wrapper(*args, **kwargs):
  110. ctx = app.request_context
  111. ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
  112. if not ragflow_ctx:
  113. raise ValueError("Get RAGFlow Context failed")
  114. connector = ragflow_ctx.conn
  115. if MODE == LaunchMode.HOST:
  116. headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
  117. token = None
  118. # lower case here, because of Starlette conversion
  119. auth = headers.get("authorization", "")
  120. if auth.startswith("Bearer "):
  121. token = auth.removeprefix("Bearer ").strip()
  122. elif "api_key" in headers:
  123. token = headers["api_key"]
  124. if required and not token:
  125. raise ValueError("RAGFlow API key or Bearer token is required.")
  126. connector.bind_api_key(token)
  127. else:
  128. connector.bind_api_key(HOST_API_KEY)
  129. return await func(*args, connector=connector, **kwargs)
  130. return wrapper
  131. return decorator
  132. @app.list_tools()
  133. @with_api_key(required=True)
  134. async def list_tools(*, connector) -> list[types.Tool]:
  135. dataset_description = connector.list_datasets()
  136. return [
  137. types.Tool(
  138. name="ragflow_retrieval",
  139. description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function."
  140. + dataset_description,
  141. inputSchema={
  142. "type": "object",
  143. "properties": {
  144. "dataset_ids": {
  145. "type": "array",
  146. "items": {"type": "string"},
  147. },
  148. "document_ids": {
  149. "type": "array",
  150. "items": {"type": "string"},
  151. },
  152. "question": {"type": "string"},
  153. },
  154. "required": ["dataset_ids", "question"],
  155. },
  156. ),
  157. ]
  158. @app.call_tool()
  159. @with_api_key(required=True)
  160. async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
  161. if name == "ragflow_retrieval":
  162. document_ids = arguments.get("document_ids", [])
  163. return connector.retrieval(
  164. dataset_ids=arguments["dataset_ids"],
  165. document_ids=document_ids,
  166. question=arguments["question"],
  167. )
  168. raise ValueError(f"Tool not found: {name}")
  169. async def handle_sse(request):
  170. async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
  171. await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
  172. return Response()
  173. def create_starlette_app():
  174. middleware = None
  175. if MODE == LaunchMode.HOST:
  176. from starlette.types import ASGIApp, Receive, Scope, Send
  177. class AuthMiddleware:
  178. def __init__(self, app: ASGIApp):
  179. self.app = app
  180. async def __call__(self, scope: Scope, receive: Receive, send: Send):
  181. if scope["type"] != "http":
  182. await self.app(scope, receive, send)
  183. return
  184. path = scope["path"]
  185. if path.startswith("/messages/") or path.startswith("/sse"):
  186. headers = dict(scope["headers"])
  187. token = None
  188. auth_header = headers.get(b"authorization")
  189. if auth_header and auth_header.startswith(b"Bearer "):
  190. token = auth_header.removeprefix(b"Bearer ").strip()
  191. elif b"api_key" in headers:
  192. token = headers[b"api_key"]
  193. if not token:
  194. response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
  195. await response(scope, receive, send)
  196. return
  197. await self.app(scope, receive, send)
  198. middleware = [Middleware(AuthMiddleware)]
  199. return Starlette(
  200. debug=True,
  201. routes=[
  202. Route("/sse", endpoint=handle_sse, methods=["GET"]),
  203. Mount("/messages/", app=sse.handle_post_message),
  204. ],
  205. middleware=middleware,
  206. )
  207. @click.command()
  208. @click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
  209. @click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
  210. @click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
  211. @click.option(
  212. "--mode",
  213. type=click.Choice(["self-host", "host"]),
  214. default="self-host",
  215. help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
  216. )
  217. @click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
  218. def main(base_url, host, port, mode, api_key):
  219. import os
  220. import uvicorn
  221. from dotenv import load_dotenv
  222. load_dotenv()
  223. global BASE_URL, HOST, PORT, MODE, HOST_API_KEY
  224. BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
  225. HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
  226. PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
  227. MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
  228. HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
  229. if MODE == "self-host" and not HOST_API_KEY:
  230. raise click.UsageError("--api-key is required when --mode is 'self-host'")
  231. print(
  232. r"""
  233. __ __ ____ ____ ____ _____ ______ _______ ____
  234. | \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \
  235. | |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |
  236. | | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <
  237. |_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\
  238. """,
  239. flush=True,
  240. )
  241. print(f"MCP launch mode: {MODE}", flush=True)
  242. print(f"MCP host: {HOST}", flush=True)
  243. print(f"MCP port: {PORT}", flush=True)
  244. print(f"MCP base_url: {BASE_URL}", flush=True)
  245. uvicorn.run(
  246. create_starlette_app(),
  247. host=HOST,
  248. port=int(PORT),
  249. )
# Script entry point: run the click command when the file is executed directly.
# The bare string below is a no-op expression that documents example launches.
if __name__ == "__main__":
    """
    Launch example:
    self-host:
    uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base-url=http://127.0.0.1:9380 --mode=self-host --api-key=ragflow-xxxxx
    host:
    uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base-url=http://127.0.0.1:9380 --mode=host
    """
    main()
  258. main()