|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
import json |
|
|
|
from collections.abc import AsyncIterator |
|
|
|
from contextlib import asynccontextmanager |
|
|
|
from functools import wraps |
|
|
|
|
|
|
|
import requests |
|
|
|
from starlette.applications import Starlette |
|
|
|
@@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan) |
|
|
|
sse = SseServerTransport("/messages/") |
|
|
|
|
|
|
|
|
|
|
|
@app.list_tools() |
|
|
|
async def list_tools() -> list[types.Tool]: |
|
|
|
ctx = app.request_context |
|
|
|
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] |
|
|
|
if not ragflow_ctx: |
|
|
|
raise ValueError("Get RAGFlow Context failed") |
|
|
|
connector = ragflow_ctx.conn |
|
|
|
def with_api_key(required=True): |
|
|
|
def decorator(func): |
|
|
|
@wraps(func) |
|
|
|
async def wrapper(*args, **kwargs): |
|
|
|
ctx = app.request_context |
|
|
|
ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx") |
|
|
|
if not ragflow_ctx: |
|
|
|
raise ValueError("Get RAGFlow Context failed") |
|
|
|
|
|
|
|
connector = ragflow_ctx.conn |
|
|
|
|
|
|
|
if MODE == LaunchMode.HOST: |
|
|
|
headers = ctx.session._init_options.capabilities.experimental.get("headers", {}) |
|
|
|
token = None |
|
|
|
|
|
|
|
# lower case here, because of Starlette conversion |
|
|
|
auth = headers.get("authorization", "") |
|
|
|
if auth.startswith("Bearer "): |
|
|
|
token = auth.removeprefix("Bearer ").strip() |
|
|
|
elif "api_key" in headers: |
|
|
|
token = headers["api_key"] |
|
|
|
|
|
|
|
if required and not token: |
|
|
|
raise ValueError("RAGFlow API key or Bearer token is required.") |
|
|
|
|
|
|
|
connector.bind_api_key(token) |
|
|
|
else: |
|
|
|
connector.bind_api_key(HOST_API_KEY) |
|
|
|
|
|
|
|
return await func(*args, connector=connector, **kwargs) |
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
if MODE == LaunchMode.HOST: |
|
|
|
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"] |
|
|
|
if not api_key: |
|
|
|
raise ValueError("RAGFlow API_KEY is required.") |
|
|
|
else: |
|
|
|
api_key = HOST_API_KEY |
|
|
|
connector.bind_api_key(api_key) |
|
|
|
|
|
|
|
@app.list_tools() |
|
|
|
@with_api_key(required=True) |
|
|
|
async def list_tools(*, connector) -> list[types.Tool]: |
|
|
|
dataset_description = connector.list_datasets() |
|
|
|
|
|
|
|
return [ |
|
|
|
@@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]: |
|
|
|
+ dataset_description, |
|
|
|
inputSchema={ |
|
|
|
"type": "object", |
|
|
|
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}}, |
|
|
|
"properties": { |
|
|
|
"dataset_ids": { |
|
|
|
"type": "array", |
|
|
|
"items": {"type": "string"}, |
|
|
|
}, |
|
|
|
"document_ids": { |
|
|
|
"type": "array", |
|
|
|
"items": {"type": "string"}, |
|
|
|
}, |
|
|
|
"question": {"type": "string"}, |
|
|
|
}, |
|
|
|
"required": ["dataset_ids", "question"], |
|
|
|
}, |
|
|
|
), |
|
|
|
@@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]: |
|
|
|
|
|
|
|
|
|
|
|
@app.call_tool() |
|
|
|
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: |
|
|
|
ctx = app.request_context |
|
|
|
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] |
|
|
|
if not ragflow_ctx: |
|
|
|
raise ValueError("Get RAGFlow Context failed") |
|
|
|
connector = ragflow_ctx.conn |
|
|
|
|
|
|
|
if MODE == LaunchMode.HOST: |
|
|
|
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"] |
|
|
|
if not api_key: |
|
|
|
raise ValueError("RAGFlow API_KEY is required.") |
|
|
|
else: |
|
|
|
api_key = HOST_API_KEY |
|
|
|
connector.bind_api_key(api_key) |
|
|
|
|
|
|
|
@with_api_key(required=True) |
|
|
|
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: |
|
|
|
if name == "ragflow_retrieval": |
|
|
|
document_ids = arguments.get("document_ids", []) |
|
|
|
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"]) |
|
|
|
return connector.retrieval( |
|
|
|
dataset_ids=arguments["dataset_ids"], |
|
|
|
document_ids=document_ids, |
|
|
|
question=arguments["question"], |
|
|
|
) |
|
|
|
raise ValueError(f"Tool not found: {name}") |
|
|
|
|
|
|
|
|
|
|
|
@@ -188,25 +213,34 @@ async def handle_sse(request): |
|
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware): |
|
|
|
async def dispatch(self, request, call_next): |
|
|
|
# Authentication is deferred, will be handled by RAGFlow core service. |
|
|
|
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"): |
|
|
|
api_key = request.headers.get("api_key") |
|
|
|
if not api_key: |
|
|
|
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401) |
|
|
|
return await call_next(request) |
|
|
|
token = None |
|
|
|
|
|
|
|
auth_header = request.headers.get("Authorization") |
|
|
|
if auth_header and auth_header.startswith("Bearer "): |
|
|
|
token = auth_header.removeprefix("Bearer ").strip() |
|
|
|
elif request.headers.get("api_key"): |
|
|
|
token = request.headers["api_key"] |
|
|
|
|
|
|
|
if not token: |
|
|
|
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401) |
|
|
|
return await call_next(request) |
|
|
|
|
|
|
|
middleware = None |
|
|
|
if MODE == LaunchMode.HOST: |
|
|
|
middleware = [Middleware(AuthMiddleware)] |
|
|
|
|
|
|
|
starlette_app = Starlette( |
|
|
|
debug=True, |
|
|
|
routes=[ |
|
|
|
Route("/sse", endpoint=handle_sse), |
|
|
|
Mount("/messages/", app=sse.handle_post_message), |
|
|
|
], |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
def create_starlette_app(): |
|
|
|
middleware = None |
|
|
|
if MODE == LaunchMode.HOST: |
|
|
|
middleware = [Middleware(AuthMiddleware)] |
|
|
|
|
|
|
|
return Starlette( |
|
|
|
debug=True, |
|
|
|
routes=[ |
|
|
|
Route("/sse", endpoint=handle_sse), |
|
|
|
Mount("/messages/", app=sse.handle_post_message), |
|
|
|
], |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
@@ -236,7 +270,7 @@ if __name__ == "__main__": |
|
|
|
default="self-host", |
|
|
|
help="Launch mode options:\n" |
|
|
|
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n" |
|
|
|
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header " |
|
|
|
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header " |
|
|
|
"indicating the user's identification.", |
|
|
|
) |
|
|
|
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY") |
|
|
|
@@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____ |
|
|
|
print(f"MCP base_url: {BASE_URL}", flush=True) |
|
|
|
|
|
|
|
uvicorn.run( |
|
|
|
starlette_app, |
|
|
|
create_starlette_app(), |
|
|
|
host=HOST, |
|
|
|
port=int(PORT), |
|
|
|
) |