Browse Source

Feat: add authorization header for MCP server based on OAuth 2.1 (#8292)

### What problem does this PR solve?

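Add support for the `Authorization: Bearer <token>` request header to the MCP server, following [OAuth
2.1, Section 5](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5). The existing `api_key` header is still accepted when no Bearer token is supplied.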

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.19.1
Yongteng Lei 4 months ago
parent
commit
a9532cb9e7
No account linked to committer's email address
2 changed files with 85 additions and 48 deletions
  1. 3
    0
      mcp/client/client.py
  2. 82
    48
      mcp/server/server.py

+ 3
- 0
mcp/client/client.py View File

try: try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification. # To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: # async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:

async with sse_client("http://localhost:9382/sse") as streams: async with sse_client("http://localhost:9382/sse") as streams:
async with ClientSession( async with ClientSession(
streams[0], streams[0],

+ 82
- 48
mcp/server/server.py View File

import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import wraps


import requests import requests
from starlette.applications import Starlette from starlette.applications import Starlette
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")




@app.list_tools()
async def list_tools() -> list[types.Tool]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
def with_api_key(required=True):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")

connector = ragflow_ctx.conn

if MODE == LaunchMode.HOST:
headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
token = None

# lower case here, because of Starlette conversion
auth = headers.get("authorization", "")
if auth.startswith("Bearer "):
token = auth.removeprefix("Bearer ").strip()
elif "api_key" in headers:
token = headers["api_key"]

if required and not token:
raise ValueError("RAGFlow API key or Bearer token is required.")

connector.bind_api_key(token)
else:
connector.bind_api_key(HOST_API_KEY)

return await func(*args, connector=connector, **kwargs)

return wrapper

return decorator


if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)


@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
dataset_description = connector.list_datasets() dataset_description = connector.list_datasets()


return [ return [
+ dataset_description, + dataset_description,
inputSchema={ inputSchema={
"type": "object", "type": "object",
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
"properties": {
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
},
"document_ids": {
"type": "array",
"items": {"type": "string"},
},
"question": {"type": "string"},
},
"required": ["dataset_ids", "question"], "required": ["dataset_ids", "question"],
}, },
), ),




@app.call_tool() @app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn

if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)

@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval": if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", []) document_ids = arguments.get("document_ids", [])
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"])
return connector.retrieval(
dataset_ids=arguments["dataset_ids"],
document_ids=document_ids,
question=arguments["question"],
)
raise ValueError(f"Tool not found: {name}") raise ValueError(f"Tool not found: {name}")






class AuthMiddleware(BaseHTTPMiddleware): class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next): async def dispatch(self, request, call_next):
# Authentication is deferred, will be handled by RAGFlow core service.
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"): if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
api_key = request.headers.get("api_key")
if not api_key:
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401)
return await call_next(request)
token = None


auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.removeprefix("Bearer ").strip()
elif request.headers.get("api_key"):
token = request.headers["api_key"]

if not token:
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
return await call_next(request)


middleware = None
if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]


starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
middleware=middleware,
)
def create_starlette_app():
middleware = None
if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]

return Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
middleware=middleware,
)




if __name__ == "__main__": if __name__ == "__main__":
default="self-host", default="self-host",
help="Launch mode options:\n" help="Launch mode options:\n"
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n" " * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n"
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header "
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header "
"indicating the user's identification.", "indicating the user's identification.",
) )
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY") parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY")
print(f"MCP base_url: {BASE_URL}", flush=True) print(f"MCP base_url: {BASE_URL}", flush=True)


uvicorn.run( uvicorn.run(
starlette_app,
create_starlette_app(),
host=HOST, host=HOST,
port=int(PORT), port=int(PORT),
) )

Loading…
Cancel
Save