Sfoglia il codice sorgente

Feat: add authorization header for MCP server based on OAuth 2.1 (#8292)

### What problem does this PR solve?

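Add support for the `Authorization: Bearer <token>` request header to the MCP server, following
[OAuth 2.1, Section 5](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5).
The existing `api_key` header is still accepted as a fallback.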

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.19.1
Yongteng Lei 4 mesi fa
parent
commit
a9532cb9e7
Nessun account collegato all'indirizzo email del committer
2 ha cambiato i file con 85 aggiunte e 48 eliminazioni
  1. 3
    0
      mcp/client/client.py
  2. 82
    48
      mcp/server/server.py

+ 3
- 0
mcp/client/client.py Vedi File

@@ -23,6 +23,9 @@ async def main():
try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:

async with sse_client("http://localhost:9382/sse") as streams:
async with ClientSession(
streams[0],

+ 82
- 48
mcp/server/server.py Vedi File

@@ -17,6 +17,7 @@
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps

import requests
from starlette.applications import Starlette
@@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan)
sse = SseServerTransport("/messages/")


@app.list_tools()
async def list_tools() -> list[types.Tool]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
def _extract_token(headers):
    """Return the credential carried by *headers*, or ``None`` if there is none.

    An ``Authorization: Bearer <token>`` header takes precedence; otherwise the
    legacy ``api_key`` header is used.

    Args:
        headers: Mapping of header name to value. Names are expected in lower
            case, because of Starlette conversion.

    Returns:
        The token string, or ``None`` when no usable credential is present.
    """
    auth = headers.get("authorization", "")
    scheme, sep, credentials = auth.partition(" ")
    # HTTP authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" must be accepted as well as "Bearer".
    if sep and scheme.lower() == "bearer":
        return credentials.strip() or None
    return headers.get("api_key") or None


def with_api_key(required=True):
    """Decorate an async handler so it receives a connector bound to an API key.

    The wrapped handler is called with an extra keyword argument ``connector``
    taken from the ``ragflow_ctx`` entry of the request's lifespan context.

    In ``LaunchMode.HOST`` the key is read from the headers stored under the
    ``"headers"`` entry of the session's experimental capabilities; in any
    other mode ``HOST_API_KEY`` is bound instead.

    Args:
        required: When true, a missing credential in host mode raises
            ``ValueError``. When false, the connector is bound with whatever
            was found (possibly ``None``).

    Raises:
        ValueError: If the RAGFlow context is missing, or if ``required`` is
            true and no API key / Bearer token was supplied in host mode.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Get RAGFlow Context failed")

            connector = ragflow_ctx.conn

            if MODE == LaunchMode.HOST:
                # `experimental` (or its "headers" entry) may be absent; treat
                # that as "no headers" so the caller gets the ValueError below
                # rather than an AttributeError.
                experimental = ctx.session._init_options.capabilities.experimental or {}
                token = _extract_token(experimental.get("headers") or {})

                if required and not token:
                    raise ValueError("RAGFlow API key or Bearer token is required.")

                connector.bind_api_key(token)
            else:
                connector.bind_api_key(HOST_API_KEY)

            return await func(*args, connector=connector, **kwargs)

        return wrapper

    return decorator

if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)

@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
dataset_description = connector.list_datasets()

return [
@@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]:
+ dataset_description,
inputSchema={
"type": "object",
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
"properties": {
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
},
"document_ids": {
"type": "array",
"items": {"type": "string"},
},
"question": {"type": "string"},
},
"required": ["dataset_ids", "question"],
},
),
@@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]:


@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn

if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)

@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", [])
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"])
return connector.retrieval(
dataset_ids=arguments["dataset_ids"],
document_ids=document_ids,
question=arguments["question"],
)
raise ValueError(f"Tool not found: {name}")


@@ -188,25 +213,34 @@ async def handle_sse(request):

class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
# Authentication is deferred, will be handled by RAGFlow core service.
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
api_key = request.headers.get("api_key")
if not api_key:
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401)
return await call_next(request)
token = None

auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.removeprefix("Bearer ").strip()
elif request.headers.get("api_key"):
token = request.headers["api_key"]

if not token:
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
return await call_next(request)

middleware = None
if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]

starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
middleware=middleware,
)
def create_starlette_app():
    """Build the Starlette application exposing the SSE and message routes.

    ``AuthMiddleware`` is installed only when the server runs in
    ``LaunchMode.HOST``; otherwise no middleware is configured.

    Returns:
        A ``Starlette`` instance with ``/sse`` routed to ``handle_sse`` and
        ``/messages/`` mounted on ``sse.handle_post_message``.
    """
    auth_layers = [Middleware(AuthMiddleware)] if MODE == LaunchMode.HOST else None

    endpoints = [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=sse.handle_post_message),
    ]

    # NOTE(review): debug is hard-coded to True — confirm this is intended
    # outside development.
    return Starlette(debug=True, routes=endpoints, middleware=auth_layers)


if __name__ == "__main__":
@@ -236,7 +270,7 @@ if __name__ == "__main__":
default="self-host",
help="Launch mode options:\n"
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n"
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header "
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header "
"indicating the user's identification.",
)
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY")
@@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____
print(f"MCP base_url: {BASE_URL}", flush=True)

uvicorn.run(
starlette_app,
create_starlette_app(),
host=HOST,
port=int(PORT),
)

Loading…
Annulla
Salva