#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import requests
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route

import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.sse import SseServerTransport

# Defaults; overridden by CLI arguments or environment variables in __main__ below.
BASE_URL = "http://127.0.0.1:9380"
HOST = "127.0.0.1"
PORT = "9382"


class RAGFlowConnector:
    """Thin HTTP client for the RAGFlow REST API."""

    def __init__(self, base_url: str, version="v1"):
        self.base_url = base_url
        self.version = version
        self.api_url = f"{self.base_url}/api/{self.version}"
        self.api_key = None

    def bind_api_key(self, api_key: str):
        self.api_key = api_key
        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}

    def _post(self, path, json=None, stream=False, files=None):
        if not self.api_key:
            return None
        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
        return res

    def _get(self, path, params=None, json=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
        return res

    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
        # Return a newline-separated list of {"description", "id"} JSON objects, one per dataset.
        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
        res = res.json()
        if res.get("code") == 0:
            result_list = []
            for data in res["data"]:
                d = {"description": data["description"], "id": data["id"]}
                result_list.append(json.dumps(d, ensure_ascii=False))
            return "\n".join(result_list)
        return ""

    def retrieval(
        self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
    ):
        if document_ids is None:
            document_ids = []
        data_json = {
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "rerank_id": rerank_id,
            "keyword": keyword,
            "question": question,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
        }
        # Send a POST request to the backend retrieval endpoint.
        res = self._post("/retrieval", json=data_json)
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
        res = res.json()
        if res.get("code") == 0:
            chunks = []
            for chunk_data in res["data"].get("chunks", []):
                chunks.append(json.dumps(chunk_data, ensure_ascii=False))
            return [types.TextContent(type="text", text="\n".join(chunks))]
        raise Exception([types.TextContent(type="text", text=res.get("message"))])


class RAGFlowCtx:
    def __init__(self, connector: RAGFlowConnector):
        self.conn = connector


@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
    ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
    try:
        yield {"ragflow_ctx": ctx}
    finally:
        pass


app = Server("ragflow-server", lifespan=server_lifespan)
sse = SseServerTransport("/messages/")


@app.list_tools()
async def list_tools() -> list[types.Tool]:
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    dataset_description = connector.list_datasets()

    return [
        types.Tool(
            name="ragflow_retrieval",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function."
            + dataset_description,
            inputSchema={
                "type": "object",
                "properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
                "required": ["dataset_ids", "question"],
            },
        ),
    ]


@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    if name == "ragflow_retrieval":
        return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=arguments.get("document_ids", []), question=arguments["question"])
    raise ValueError(f"Tool not found: {name}")


async def handle_sse(request):
    async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
        # Forward the incoming HTTP headers (including api_key) to the MCP session via
        # experimental capabilities, so the tool handlers can authenticate against RAGFlow.
        await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))


class AuthMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
            api_key = request.headers.get("api_key")
            if not api_key:
                return JSONResponse({"error": "Missing api_key header"}, status_code=401)
        return await call_next(request)


starlette_app = Starlette(
    debug=True,
    routes=[
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=sse.handle_post_message),
    ],
    middleware=[Middleware(AuthMiddleware)],
)

if __name__ == "__main__":
    """
    Launch example:

        uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380
    """
    import argparse
    import os

    import uvicorn
    from dotenv import load_dotenv

    load_dotenv()

    parser = argparse.ArgumentParser(description="RAGFlow MCP Server; `base_url` and `api_key` are needed.")
    parser.add_argument("--base_url", type=str, default="http://127.0.0.1:9380", help="api_url: http://<host_address>")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="RAGFlow MCP SERVER host")
    parser.add_argument("--port", type=str, default="9382", help="RAGFlow MCP SERVER port")
    args = parser.parse_args()

    # Environment variables take precedence over the command-line defaults.
    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", args.base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", args.host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", args.port)

    print(
        r"""
 __  __   ____  ____    ____   _____  ____  __     __ _____  ____
|  \/  | / ___||  _ \   / ___| | ____||  _ \ \ \   / /| ____||  _ \
| |\/| || |    | |_) |  \___ \ |  _|  | |_) | \ \ / / |  _|  | |_) |
| |  | || |___ |  __/    ___) || |___ |  _ <   \ V /  | |___ |  _ <
|_|  |_| \____||_|      |____/ |_____||_| \_\   \_/   |_____||_| \_\
        """,
        flush=True,
    )
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)

    uvicorn.run(
        starlette_app,
        host=HOST,
        port=int(PORT),
    )
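
# ---------------------------------------------------------------------------
# Usage sketch (not part of the server): a minimal, hypothetical MCP client
# that connects to this server over SSE, sends the RAGFlow API key in the
# `api_key` header expected by AuthMiddleware, and calls the retrieval tool.
# The URL, API key, and dataset id are placeholders; `sse_client` and
# `ClientSession` come from the `mcp` Python SDK.
# ---------------------------------------------------------------------------
# import asyncio
#
# from mcp import ClientSession
# from mcp.client.sse import sse_client
#
#
# async def example_client():
#     async with sse_client("http://127.0.0.1:9382/sse", headers={"api_key": "<YOUR_RAGFLOW_API_KEY>"}) as (read, write):
#         async with ClientSession(read, write) as session:
#             await session.initialize()
#             tools = await session.list_tools()
#             print([tool.name for tool in tools.tools])
#             result = await session.call_tool(
#                 name="ragflow_retrieval",
#                 arguments={"dataset_ids": ["<DATASET_ID>"], "question": "What is RAGFlow?"},
#             )
#             print(result)
#
#
# asyncio.run(example_client())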