Feat: MCP host mode supports STREAMABLE-HTTP endpoint (#13037)

### What problem does this PR solve?

MCP host mode supports STREAMABLE-HTTP endpoint

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2026-02-06 16:22:43 +08:00
committed by GitHub
parent c130ac0f88
commit 279b01a028
2 changed files with 117 additions and 56 deletions

View File

@@ -19,6 +19,10 @@ from mcp.client.streamable_http import streamablehttp_client
async def main(): async def main():
try: try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"api_key": "ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"Authorization": "Bearer ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _): async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session: async with ClientSession(read_stream, write_stream) as session:
await session.initialize() await session.initialize()

View File

@@ -22,18 +22,18 @@ from collections import OrderedDict
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import wraps from functools import wraps
from typing import Any
import click import click
import httpx import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from strenum import StrEnum from strenum import StrEnum
import mcp.types as types
from mcp.server.lowlevel import Server
class LaunchMode(StrEnum): class LaunchMode(StrEnum):
SELF_HOST = "self-host" SELF_HOST = "self-host"
@@ -68,10 +68,6 @@ class RAGFlowConnector:
self.api_url = f"{self.base_url}/api/{self.version}" self.api_url = f"{self.base_url}/api/{self.version}"
self._async_client = None self._async_client = None
def bind_api_key(self, api_key: str):
self.api_key = api_key
self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
async def _get_client(self): async def _get_client(self):
if self._async_client is None: if self._async_client is None:
self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0)) self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
@@ -82,16 +78,18 @@ class RAGFlowConnector:
await self._async_client.aclose() await self._async_client.aclose()
self._async_client = None self._async_client = None
async def _post(self, path, json=None, stream=False, files=None): async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
if not self.api_key: if not api_key:
return None return None
client = await self._get_client() client = await self._get_client()
res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header) res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
return res return res
async def _get(self, path, params=None): async def _get(self, path, params=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client() client = await self._get_client()
res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header) res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
return res return res
def _is_cache_valid(self, ts): def _is_cache_valid(self, ts):
@@ -129,8 +127,18 @@ class RAGFlowConnector:
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp()) self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id) self._document_metadata_cache.move_to_end(dataset_id)
async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None): async def list_datasets(
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) self,
*,
api_key: str,
page: int = 1,
page_size: int = 1000,
orderby: str = "create_time",
desc: bool = True,
id: str | None = None,
name: str | None = None,
):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
if not res or res.status_code != 200: if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")]) raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
@@ -145,6 +153,8 @@ class RAGFlowConnector:
async def retrieval( async def retrieval(
self, self,
*,
api_key: str,
dataset_ids, dataset_ids,
document_ids=None, document_ids=None,
question="", question="",
@@ -162,7 +172,7 @@ class RAGFlowConnector:
# If no dataset_ids provided or empty list, get all available dataset IDs # If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids: if not dataset_ids:
dataset_list_str = await self.list_datasets() dataset_list_str = await self.list_datasets(api_key=api_key)
dataset_ids = [] dataset_ids = []
# Parse the dataset list to extract IDs # Parse the dataset list to extract IDs
@@ -189,7 +199,7 @@ class RAGFlowConnector:
"document_ids": document_ids, "document_ids": document_ids,
} }
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary) # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = await self._post("/retrieval", json=data_json) res = await self._post("/retrieval", json=data_json, api_key=api_key)
if not res or res.status_code != 200: if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")]) raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
@@ -199,7 +209,7 @@ class RAGFlowConnector:
chunks = [] chunks = []
# Cache document metadata and dataset information # Cache document metadata and dataset information
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh) document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)
# Process chunks with enhanced field mapping including per-chunk metadata # Process chunks with enhanced field mapping including per-chunk metadata
for chunk_data in data.get("chunks", []): for chunk_data in data.get("chunks", []):
@@ -228,7 +238,7 @@ class RAGFlowConnector:
raise Exception([types.TextContent(type="text", text=res.get("message"))]) raise Exception([types.TextContent(type="text", text=res.get("message"))])
async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False): async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
"""Cache document metadata for all documents in the specified datasets""" """Cache document metadata for all documents in the specified datasets"""
document_cache = {} document_cache = {}
dataset_cache = {} dataset_cache = {}
@@ -238,7 +248,7 @@ class RAGFlowConnector:
dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id) dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
if not dataset_meta: if not dataset_meta:
# First get dataset info for name # First get dataset info for name
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}) dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
if dataset_res and dataset_res.status_code == 200: if dataset_res and dataset_res.status_code == 200:
dataset_data = dataset_res.json() dataset_data = dataset_res.json()
if dataset_data.get("code") == 0 and dataset_data.get("data"): if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -255,7 +265,9 @@ class RAGFlowConnector:
doc_id_meta_list = [] doc_id_meta_list = []
docs = {} docs = {}
while page: while page:
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}") docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
if not docs_res:
break
docs_data = docs_res.json() docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"): if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]: for doc in docs_data["data"]["docs"]:
@@ -335,9 +347,59 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
app = Server("ragflow-mcp-server", lifespan=sse_lifespan) app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"
def with_api_key(required=True): def _to_text(value: Any) -> str:
if isinstance(value, bytes):
return value.decode(errors="ignore")
return str(value)
def _extract_token_from_headers(headers: Any) -> str | None:
if not headers or not hasattr(headers, "get"):
return None
auth_keys = ("authorization", "Authorization", b"authorization", b"Authorization")
for key in auth_keys:
auth = headers.get(key)
if not auth:
continue
auth_text = _to_text(auth).strip()
if auth_text.lower().startswith("bearer "):
token = auth_text[7:].strip()
if token:
return token
api_key_keys = ("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")
for key in api_key_keys:
token = headers.get(key)
if token:
token_text = _to_text(token).strip()
if token_text:
return token_text
return None
def _extract_token_from_request(request: Any) -> str | None:
if request is None:
return None
state = getattr(request, "state", None)
if state is not None:
token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
if token:
return token
token = _extract_token_from_headers(getattr(request, "headers", None))
if token and state is not None:
setattr(state, AUTH_TOKEN_STATE_KEY, token)
return token
def with_api_key(required: bool = True):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
@@ -347,26 +409,14 @@ def with_api_key(required=True):
raise ValueError("Get RAGFlow Context failed") raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn connector = ragflow_ctx.conn
api_key = HOST_API_KEY
if MODE == LaunchMode.HOST: if MODE == LaunchMode.HOST:
headers = ctx.session._init_options.capabilities.experimental.get("headers", {}) api_key = _extract_token_from_request(getattr(ctx, "request", None)) or ""
token = None if required and not api_key:
# lower case here, because of Starlette conversion
auth = headers.get("authorization", "")
if auth.startswith("Bearer "):
token = auth.removeprefix("Bearer ").strip()
elif "api_key" in headers:
token = headers["api_key"]
if required and not token:
raise ValueError("RAGFlow API key or Bearer token is required.") raise ValueError("RAGFlow API key or Bearer token is required.")
connector.bind_api_key(token) return await func(*args, connector=connector, api_key=api_key, **kwargs)
else:
connector.bind_api_key(HOST_API_KEY)
return await func(*args, connector=connector, **kwargs)
return wrapper return wrapper
@@ -375,8 +425,8 @@ def with_api_key(required=True):
@app.list_tools() @app.list_tools()
@with_api_key(required=True) @with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]: async def list_tools(*, connector: RAGFlowConnector, api_key: str) -> list[types.Tool]:
dataset_description = await connector.list_datasets() dataset_description = await connector.list_datasets(api_key=api_key)
return [ return [
types.Tool( types.Tool(
@@ -446,7 +496,13 @@ async def list_tools(*, connector) -> list[types.Tool]:
@app.call_tool() @app.call_tool()
@with_api_key(required=True) @with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: async def call_tool(
name: str,
arguments: dict,
*,
connector: RAGFlowConnector,
api_key: str,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval": if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", []) document_ids = arguments.get("document_ids", [])
dataset_ids = arguments.get("dataset_ids", []) dataset_ids = arguments.get("dataset_ids", [])
@@ -462,7 +518,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
# If no dataset_ids provided or empty list, get all available dataset IDs # If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids: if not dataset_ids:
dataset_list_str = await connector.list_datasets() dataset_list_str = await connector.list_datasets(api_key=api_key)
dataset_ids = [] dataset_ids = []
# Parse the dataset list to extract IDs # Parse the dataset list to extract IDs
@@ -477,6 +533,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
continue continue
return await connector.retrieval( return await connector.retrieval(
api_key=api_key,
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
document_ids=document_ids, document_ids=document_ids,
question=question, question=question,
@@ -510,17 +567,13 @@ def create_starlette_app():
path = scope["path"] path = scope["path"]
if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"): if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
headers = dict(scope["headers"]) headers = dict(scope["headers"])
token = None token = _extract_token_from_headers(headers)
auth_header = headers.get(b"authorization")
if auth_header and auth_header.startswith(b"Bearer "):
token = auth_header.removeprefix(b"Bearer ").strip()
elif b"api_key" in headers:
token = headers[b"api_key"]
if not token: if not token:
response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401) response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
await response(scope, receive, send) await response(scope, receive, send)
return return
scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token
await self.app(scope, receive, send) await self.app(scope, receive, send)
@@ -547,9 +600,8 @@ def create_starlette_app():
# Add streamable HTTP route if enabled # Add streamable HTTP route if enabled
streamablehttp_lifespan = None streamablehttp_lifespan = None
if TRANSPORT_STREAMABLE_HTTP_ENABLED: if TRANSPORT_STREAMABLE_HTTP_ENABLED:
from starlette.types import Receive, Scope, Send
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.types import Receive, Scope, Send
session_manager = StreamableHTTPSessionManager( session_manager = StreamableHTTPSessionManager(
app=app, app=app,
@@ -558,8 +610,11 @@ def create_starlette_app():
stateless=True, stateless=True,
) )
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: class StreamableHTTPEntry:
await session_manager.handle_request(scope, receive, send) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)
streamable_http_entry = StreamableHTTPEntry()
@asynccontextmanager @asynccontextmanager
async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]: async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
@@ -570,7 +625,12 @@ def create_starlette_app():
finally: finally:
logging.info("StreamableHTTP application shutting down...") logging.info("StreamableHTTP application shutting down...")
routes.append(Mount("/mcp", app=handle_streamable_http)) routes.extend(
[
Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
Mount("/mcp", app=streamable_http_entry),
]
)
return Starlette( return Starlette(
debug=True, debug=True,
@@ -631,9 +691,6 @@ def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_s
if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY: if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
raise click.UsageError("--api-key is required when --mode is 'self-host'") raise click.UsageError("--api-key is required when --mode is 'self-host'")
if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE: if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
JSON_RESPONSE = False JSON_RESPONSE = False
@@ -690,7 +747,7 @@ if __name__ == "__main__":
--base-url=http://127.0.0.1:9380 \ --base-url=http://127.0.0.1:9380 \
--mode=self-host --api-key=ragflow-xxxxx --mode=self-host --api-key=ragflow-xxxxx
2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers): 2. Host mode (multi-tenant, clients must provide Authorization headers):
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \ uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
--base-url=http://127.0.0.1:9380 \ --base-url=http://127.0.0.1:9380 \
--mode=host --mode=host