From a9532cb9e7e2b9215ce1de0d0aaa7ea8b46bc872 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Tue, 17 Jun 2025 09:29:12 +0800 Subject: [PATCH] Feat: add authorization header for MCP server based on OAuth 2.1 (#8292) ### What problem does this PR solve? Add authorization header for MCP server based on [OAuth 2.1](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5). ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- mcp/client/client.py | 3 + mcp/server/server.py | 130 +++++++++++++++++++++++++++---------------- 2 files changed, 85 insertions(+), 48 deletions(-) diff --git a/mcp/client/client.py b/mcp/client/client.py index 2f3ad81ce..3c54ea030 100644 --- a/mcp/client/client.py +++ b/mcp/client/client.py @@ -23,6 +23,9 @@ async def main(): try: # To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification. # async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: + # Or follow the requirements of OAuth 2.1 Section 5 with Authorization header + # async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: + async with sse_client("http://localhost:9382/sse") as streams: async with ClientSession( streams[0], diff --git a/mcp/server/server.py b/mcp/server/server.py index de87c221d..743cd16f9 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -17,6 +17,7 @@ import json from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from functools import wraps import requests from starlette.applications import Starlette @@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan) sse = SseServerTransport("/messages/") +def with_api_key(required=True): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + ctx = app.request_context + ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx") + if not ragflow_ctx: + raise ValueError("Get RAGFlow Context failed") + + connector = ragflow_ctx.conn + + if MODE == LaunchMode.HOST: + headers = ctx.session._init_options.capabilities.experimental.get("headers", {}) + token = None + + # lower case here, because of Starlette conversion + auth = headers.get("authorization", "") + if auth.startswith("Bearer "): + token = auth.removeprefix("Bearer ").strip() + elif "api_key" in headers: + token = headers["api_key"] + + if required and not token: + raise ValueError("RAGFlow API key or Bearer token is required.") + + connector.bind_api_key(token) + else: + connector.bind_api_key(HOST_API_KEY) + + return await func(*args, connector=connector, **kwargs) + + return wrapper + + return decorator + + @app.list_tools() -async def list_tools() -> list[types.Tool]: - ctx = app.request_context - ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] - if not ragflow_ctx: - raise ValueError("Get RAGFlow Context failed") - connector = ragflow_ctx.conn - - if MODE == LaunchMode.HOST: - api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"] - if not api_key: - raise ValueError("RAGFlow API_KEY is required.") - else: - api_key = HOST_API_KEY - connector.bind_api_key(api_key) - +@with_api_key(required=True) +async def list_tools(*, connector) -> list[types.Tool]: dataset_description = connector.list_datasets() return [ @@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]: + dataset_description, inputSchema={ "type": "object", - "properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}}, + "properties": { + "dataset_ids": { + "type": "array", + "items": {"type": "string"}, + }, + "document_ids": { + "type": "array", + "items": {"type": "string"}, + }, + "question": {"type": "string"}, + }, "required": ["dataset_ids", "question"], }, ), @@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]: @app.call_tool() -async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - ctx = app.request_context - ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] - if not ragflow_ctx: - raise ValueError("Get RAGFlow Context failed") - connector = ragflow_ctx.conn - - if MODE == LaunchMode.HOST: - api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"] - if not api_key: - raise ValueError("RAGFlow API_KEY is required.") - else: - api_key = HOST_API_KEY - connector.bind_api_key(api_key) - +@with_api_key(required=True) +async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: if name == "ragflow_retrieval": document_ids = arguments.get("document_ids", []) - return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"]) + return connector.retrieval( + dataset_ids=arguments["dataset_ids"], + document_ids=document_ids, + question=arguments["question"], + ) raise ValueError(f"Tool not found: {name}") @@ -188,25 +213,34 @@ async def handle_sse(request): class AuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): + # Authentication is deferred, will be handled by RAGFlow core service. if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"): - api_key = request.headers.get("api_key") - if not api_key: - return JSONResponse({"error": "Missing unauthorization header"}, status_code=401) + token = None + + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.removeprefix("Bearer ").strip() + elif request.headers.get("api_key"): + token = request.headers["api_key"] + + if not token: + return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401) return await call_next(request) -middleware = None -if MODE == LaunchMode.HOST: - middleware = [Middleware(AuthMiddleware)] +def create_starlette_app(): + middleware = None + if MODE == LaunchMode.HOST: + middleware = [Middleware(AuthMiddleware)] -starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - middleware=middleware, -) + return Starlette( + debug=True, + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + middleware=middleware, + ) if __name__ == "__main__": @@ -236,7 +270,7 @@ if __name__ == "__main__": default="self-host", help="Launch mode options:\n" " * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n" - " * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header " + " * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header " "indicating the user's identification.", ) parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY") @@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____ print(f"MCP base_url: {BASE_URL}", flush=True) uvicorn.run( - starlette_app, + create_starlette_app(), host=HOST, port=int(PORT), )