Feat: add authorization header for MCP server based on OAuth 2.1 (#8292)

### What problem does this PR solve?

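Add support for the `Authorization: Bearer <token>` request header to the MCP server, following Section 5 of [OAuth
2.1](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5). The existing `api_key` header is still accepted as a fallback.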

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-06-17 09:29:12 +08:00
committed by GitHub
parent efc3caf702
commit a9532cb9e7
2 changed files with 85 additions and 48 deletions

View File

@ -23,6 +23,9 @@ async def main():
try: try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification. # To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: # async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
async with sse_client("http://localhost:9382/sse") as streams: async with sse_client("http://localhost:9382/sse") as streams:
async with ClientSession( async with ClientSession(
streams[0], streams[0],

View File

@ -17,6 +17,7 @@
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import wraps
import requests import requests
from starlette.applications import Starlette from starlette.applications import Starlette
@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan)
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
def _extract_token(headers):
    """Return the caller's credential from *headers*, or ``None`` if absent.

    An ``authorization`` header carrying a non-empty Bearer credential takes
    precedence; otherwise the value of the ``api_key`` header is returned.

    Args:
        headers: Mapping of header names to values. Keys are looked up in
            lower case, because Starlette lower-cases header names.

    Returns:
        The token string, or ``None`` when neither header supplies one.
    """
    auth = headers.get("authorization") or ""
    scheme, _, credentials = auth.partition(" ")
    credentials = credentials.strip()
    # HTTP authentication scheme names are case-insensitive, so "bearer xyz"
    # must be accepted just like "Bearer xyz".
    if scheme.lower() == "bearer" and credentials:
        return credentials
    # No usable Bearer credential: fall back to the api_key header.
    return headers.get("api_key")


def with_api_key(required=True):
    """Build a decorator that injects an API-key-bound connector into a handler.

    The wrapped coroutine is called with an extra keyword argument
    ``connector``, taken from the ``ragflow_ctx`` entry of the current request's
    lifespan context. In ``LaunchMode.HOST`` the key is read from the request
    headers stored in the session's experimental capabilities (Bearer token
    first, then ``api_key``); in any other mode ``HOST_API_KEY`` is bound.

    Args:
        required: When true (the default), a missing token in host mode raises
            ``ValueError`` instead of binding an empty key.

    Returns:
        A decorator for async handlers that accept a ``connector`` keyword.

    Raises:
        ValueError: From the wrapped call, if the RAGFlow context is missing,
            or if ``required`` is true and no token was supplied in host mode.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Get RAGFlow Context failed")
            connector = ragflow_ctx.conn

            if MODE == LaunchMode.HOST:
                # `experimental` (or its "headers" entry) may be unset; treat
                # that as "no headers" rather than failing on attribute access.
                experimental = ctx.session._init_options.capabilities.experimental or {}
                headers = experimental.get("headers") or {}
                token = _extract_token(headers)
                if required and not token:
                    raise ValueError("RAGFlow API key or Bearer token is required.")
                connector.bind_api_key(token)
            else:
                connector.bind_api_key(HOST_API_KEY)

            return await func(*args, connector=connector, **kwargs)

        return wrapper

    return decorator
@app.list_tools() @app.list_tools()
async def list_tools() -> list[types.Tool]: @with_api_key(required=True)
ctx = app.request_context async def list_tools(*, connector) -> list[types.Tool]:
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)
dataset_description = connector.list_datasets() dataset_description = connector.list_datasets()
return [ return [
@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]:
+ dataset_description, + dataset_description,
inputSchema={ inputSchema={
"type": "object", "type": "object",
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}}, "properties": {
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
},
"document_ids": {
"type": "array",
"items": {"type": "string"},
},
"question": {"type": "string"},
},
"required": ["dataset_ids", "question"], "required": ["dataset_ids", "question"],
}, },
), ),
@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]:
@app.call_tool() @app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: @with_api_key(required=True)
ctx = app.request_context async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)
if name == "ragflow_retrieval": if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", []) document_ids = arguments.get("document_ids", [])
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"]) return connector.retrieval(
dataset_ids=arguments["dataset_ids"],
document_ids=document_ids,
question=arguments["question"],
)
raise ValueError(f"Tool not found: {name}") raise ValueError(f"Tool not found: {name}")
@ -188,25 +213,34 @@ async def handle_sse(request):
class AuthMiddleware(BaseHTTPMiddleware): class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next): async def dispatch(self, request, call_next):
# Authentication is deferred, will be handled by RAGFlow core service.
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"): if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
api_key = request.headers.get("api_key") token = None
if not api_key:
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401) auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.removeprefix("Bearer ").strip()
elif request.headers.get("api_key"):
token = request.headers["api_key"]
if not token:
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
return await call_next(request) return await call_next(request)
middleware = None def create_starlette_app():
if MODE == LaunchMode.HOST: middleware = None
middleware = [Middleware(AuthMiddleware)] if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]
starlette_app = Starlette( return Starlette(
debug=True, debug=True,
routes=[ routes=[
Route("/sse", endpoint=handle_sse), Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message), Mount("/messages/", app=sse.handle_post_message),
], ],
middleware=middleware, middleware=middleware,
) )
if __name__ == "__main__": if __name__ == "__main__":
@ -236,7 +270,7 @@ if __name__ == "__main__":
default="self-host", default="self-host",
help="Launch mode options:\n" help="Launch mode options:\n"
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n" " * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n"
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header " " * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header "
"indicating the user's identification.", "indicating the user's identification.",
) )
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY") parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY")
@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____
print(f"MCP base_url: {BASE_URL}", flush=True) print(f"MCP base_url: {BASE_URL}", flush=True)
uvicorn.run( uvicorn.run(
starlette_app, create_starlette_app(),
host=HOST, host=HOST,
port=int(PORT), port=int(PORT),
) )