From 279b01a0286a0049f78d5deecc38055c026d07da Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Fri, 6 Feb 2026 16:22:43 +0800
Subject: [PATCH] Feat: MCP host mode supports STREAMABLE-HTTP endpoint (#13037)

### What problem does this PR solve?

MCP host mode supports STREAMABLE-HTTP endpoint

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 mcp/client/streamable_http_client.py |   4 +
 mcp/server/server.py                 | 169 ++++++++++++++++++---
 2 files changed, 117 insertions(+), 56 deletions(-)

diff --git a/mcp/client/streamable_http_client.py b/mcp/client/streamable_http_client.py
index ec679e650..d6f317c5b 100644
--- a/mcp/client/streamable_http_client.py
+++ b/mcp/client/streamable_http_client.py
@@ -19,6 +19,10 @@ from mcp.client.streamable_http import streamablehttp_client

 async def main():
     try:
+        # To access a RAGFlow server in `host` mode, attach an `api_key` header to each request to identify yourself.
+        # async with streamablehttp_client("http://localhost:9382/mcp/", headers={"api_key": "ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
+        # Or follow the requirements of OAuth 2.1 Section 5 and send an Authorization header instead:
+        # async with streamablehttp_client("http://localhost:9382/mcp/", headers={"Authorization": "Bearer ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
         async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _):
             async with ClientSession(read_stream, write_stream) as session:
                 await session.initialize()
diff --git a/mcp/server/server.py b/mcp/server/server.py
index 3c83ea2b5..07cb10d94 100644
--- a/mcp/server/server.py
+++ b/mcp/server/server.py
@@ -22,18 +22,18 @@ from collections import OrderedDict
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import wraps
+from typing import Any

 import click
 import httpx
+import mcp.types as types
+from mcp.server.lowlevel import Server
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse, Response
 from starlette.routing import Mount, Route
 from strenum import StrEnum

-import mcp.types as types
-from mcp.server.lowlevel import Server
-

 class LaunchMode(StrEnum):
     SELF_HOST = "self-host"
@@ -68,10 +68,6 @@ class RAGFlowConnector:
         self.api_url = f"{self.base_url}/api/{self.version}"
         self._async_client = None

-    def bind_api_key(self, api_key: str):
-        self.api_key = api_key
-        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
-
     async def _get_client(self):
         if self._async_client is None:
             self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
@@ -82,16 +78,18 @@ class RAGFlowConnector:
         await self._async_client.aclose()
         self._async_client = None

-    async def _post(self, path, json=None, stream=False, files=None):
-        if not self.api_key:
+    async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
+        if not api_key:
             return None
         client = await self._get_client()
-        res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
+        res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
         return res

-    async def _get(self, path, params=None):
+    async def _get(self, path, params=None, api_key: str = ""):
+        if not api_key:
+            return None
         client = await self._get_client()
-        res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
+        res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
         return res

     def _is_cache_valid(self, ts):
@@ -129,8 +127,18 @@ class RAGFlowConnector:
         self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
         self._document_metadata_cache.move_to_end(dataset_id)

-    async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
-        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
+    async def list_datasets(
+        self,
+        *,
+        api_key: str,
+        page: int = 1,
+        page_size: int = 1000,
+        orderby: str = "create_time",
+        desc: bool = True,
+        id: str | None = None,
+        name: str | None = None,
+    ):
+        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
         if not res or res.status_code != 200:
             raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

@@ -145,6 +153,8 @@ class RAGFlowConnector:

     async def retrieval(
         self,
+        *,
+        api_key: str,
         dataset_ids,
         document_ids=None,
         question="",
@@ -162,7 +172,7 @@ class RAGFlowConnector:

         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = await self.list_datasets()
+            dataset_list_str = await self.list_datasets(api_key=api_key)
             dataset_ids = []

             # Parse the dataset list to extract IDs
@@ -189,7 +199,7 @@ class RAGFlowConnector:
             "document_ids": document_ids,
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = await self._post("/retrieval", json=data_json)
+        res = await self._post("/retrieval", json=data_json, api_key=api_key)
         if not res or res.status_code != 200:
             raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

@@ -199,7 +209,7 @@ class RAGFlowConnector:
         chunks = []

         # Cache document metadata and dataset information
-        document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+        document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)

         # Process chunks with enhanced field mapping including per-chunk metadata
         for chunk_data in data.get("chunks", []):
@@ -228,7 +238,7 @@ class RAGFlowConnector:

         raise Exception([types.TextContent(type="text", text=res.get("message"))])

-    async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+    async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
         """Cache document metadata for all documents in the specified datasets"""
         document_cache = {}
         dataset_cache = {}
@@ -238,7 +248,7 @@ class RAGFlowConnector:
             dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
             if not dataset_meta:
                 # First get dataset info for name
-                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
                 if dataset_res and dataset_res.status_code == 200:
                     dataset_data = dataset_res.json()
                     if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -255,7 +265,9 @@ class RAGFlowConnector:
             doc_id_meta_list = []
             docs = {}
             while page:
-                docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
+                docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
+                if not docs_res:
+                    break
                 docs_data = docs_res.json()
                 if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                     for doc in docs_data["data"]["docs"]:
@@ -335,9 +347,59 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:

 app = Server("ragflow-mcp-server", lifespan=sse_lifespan)

+AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"

-def with_api_key(required=True):
+
+def _to_text(value: Any) -> str:
+    if isinstance(value, bytes):
+        return value.decode(errors="ignore")
+    return str(value)
+
+
+def _extract_token_from_headers(headers: Any) -> str | None:
+    if not headers or not hasattr(headers, "get"):
+        return None
+
+    auth_keys = ("authorization", "Authorization", b"authorization", b"Authorization")
+    for key in auth_keys:
+        auth = headers.get(key)
+        if not auth:
+            continue
+        auth_text = _to_text(auth).strip()
+        if auth_text.lower().startswith("bearer "):
+            token = auth_text[7:].strip()
+            if token:
+                return token
+
+    api_key_keys = ("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")
+    for key in api_key_keys:
+        token = headers.get(key)
+        if token:
+            token_text = _to_text(token).strip()
+            if token_text:
+                return token_text
+
+    return None
+
+
+def _extract_token_from_request(request: Any) -> str | None:
+    if request is None:
+        return None
+
+    state = getattr(request, "state", None)
+    if state is not None:
+        token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
+        if token:
+            return token
+
+    token = _extract_token_from_headers(getattr(request, "headers", None))
+    if token and state is not None:
+        setattr(state, AUTH_TOKEN_STATE_KEY, token)
+
+    return token
+
+
+def with_api_key(required: bool = True):
     def decorator(func):
         @wraps(func)
         async def wrapper(*args, **kwargs):
@@ -347,26 +409,14 @@ def with_api_key(required=True):
                 raise ValueError("Get RAGFlow Context failed")

             connector = ragflow_ctx.conn
+            api_key = HOST_API_KEY
             if MODE == LaunchMode.HOST:
-                headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
-                token = None
-
-                # lower case here, because of Starlette conversion
-                auth = headers.get("authorization", "")
-                if auth.startswith("Bearer "):
-                    token = auth.removeprefix("Bearer ").strip()
-                elif "api_key" in headers:
-                    token = headers["api_key"]
-
-                if required and not token:
+                api_key = _extract_token_from_request(getattr(ctx, "request", None)) or ""
+                if required and not api_key:
                     raise ValueError("RAGFlow API key or Bearer token is required.")
-                connector.bind_api_key(token)
-            else:
-                connector.bind_api_key(HOST_API_KEY)
-
-            return await func(*args, connector=connector, **kwargs)
+            return await func(*args, connector=connector, api_key=api_key, **kwargs)

         return wrapper
@@ -375,8 +425,8 @@ def with_api_key(required=True):

 @app.list_tools()
 @with_api_key(required=True)
-async def list_tools(*, connector) -> list[types.Tool]:
-    dataset_description = await connector.list_datasets()
+async def list_tools(*, connector: RAGFlowConnector, api_key: str) -> list[types.Tool]:
+    dataset_description = await connector.list_datasets(api_key=api_key)

     return [
         types.Tool(
@@ -446,7 +496,13 @@ async def list_tools(*, connector) -> list[types.Tool]:

 @app.call_tool()
 @with_api_key(required=True)
-async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+async def call_tool(
+    name: str,
+    arguments: dict,
+    *,
+    connector: RAGFlowConnector,
+    api_key: str,
+) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
     if name == "ragflow_retrieval":
         document_ids = arguments.get("document_ids", [])
         dataset_ids = arguments.get("dataset_ids", [])
@@ -462,7 +518,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text

         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = await connector.list_datasets()
+            dataset_list_str = await connector.list_datasets(api_key=api_key)
             dataset_ids = []

             # Parse the dataset list to extract IDs
@@ -477,6 +533,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
                     continue

         return await connector.retrieval(
+            api_key=api_key,
             dataset_ids=dataset_ids,
             document_ids=document_ids,
             question=question,
@@ -510,17 +567,13 @@ def create_starlette_app():
             path = scope["path"]
             if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
                 headers = dict(scope["headers"])
-                token = None
-                auth_header = headers.get(b"authorization")
-                if auth_header and auth_header.startswith(b"Bearer "):
-                    token = auth_header.removeprefix(b"Bearer ").strip()
-                elif b"api_key" in headers:
-                    token = headers[b"api_key"]
+                token = _extract_token_from_headers(headers)

                 if not token:
                     response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                     await response(scope, receive, send)
                     return
+                scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token

             await self.app(scope, receive, send)

@@ -547,9 +600,8 @@ def create_starlette_app():
     # Add streamable HTTP route if enabled
     streamablehttp_lifespan = None
     if TRANSPORT_STREAMABLE_HTTP_ENABLED:
-        from starlette.types import Receive, Scope, Send
-
         from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
+        from starlette.types import Receive, Scope, Send

         session_manager = StreamableHTTPSessionManager(
             app=app,
@@ -558,8 +610,11 @@ def create_starlette_app():
             stateless=True,
         )

-        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
-            await session_manager.handle_request(scope, receive, send)
+        class StreamableHTTPEntry:
+            async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+                await session_manager.handle_request(scope, receive, send)
+
+        streamable_http_entry = StreamableHTTPEntry()

         @asynccontextmanager
         async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
@@ -570,7 +625,12 @@ def create_starlette_app():
                 finally:
                     logging.info("StreamableHTTP application shutting down...")

-        routes.append(Mount("/mcp", app=handle_streamable_http))
+        routes.extend(
+            [
+                Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
+                Mount("/mcp", app=streamable_http_entry),
+            ]
+        )

     return Starlette(
         debug=True,
@@ -631,9 +691,6 @@ def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_s
     if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
         raise click.UsageError("--api-key is required when --mode is 'self-host'")

-    if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
-        raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
-
     if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
         JSON_RESPONSE = False

@@ -690,7 +747,7 @@ if __name__ == "__main__":
          --base-url=http://127.0.0.1:9380 \
          --mode=self-host --api-key=ragflow-xxxxx

-  2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers):
+  2. Host mode (multi-tenant, clients must provide Authorization headers):
      uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
          --base-url=http://127.0.0.1:9380 \
          --mode=host
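
For reference, below is a minimal client-side sketch of how the new host-mode STREAMABLE-HTTP endpoint is meant to be exercised. It is adapted from the commented-out lines this patch adds to mcp/client/streamable_http_client.py; the API key, the question text, and the server address are placeholders rather than values shipped with the PR.

    import asyncio

    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client

    RAGFLOW_API_KEY = "ragflow-xxxxx"  # placeholder: use a real tenant API key in host mode


    async def main():
        # Host mode: every request must carry the tenant credential, either as a
        # Bearer token (OAuth 2.1 Section 5) or as a plain `api_key` header.
        async with streamablehttp_client(
            "http://localhost:9382/mcp/",
            headers={"Authorization": f"Bearer {RAGFLOW_API_KEY}"},
        ) as (read_stream, write_stream, _):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # List the tools exposed by the RAGFlow MCP server.
                tools = await session.list_tools()
                print([tool.name for tool in tools.tools])

                # Retrieve chunks; an empty dataset_ids list lets the server fall back
                # to all datasets visible to this API key.
                result = await session.call_tool(
                    "ragflow_retrieval",
                    {"dataset_ids": [], "document_ids": [], "question": "How to configure RAGFlow?"},  # placeholder question
                )
                print(result)


    if __name__ == "__main__":
        asyncio.run(main())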