Feat: MCP host mode supports STREAMABLE-HTTP endpoint (#13037)

### What problem does this PR solve?

MCP host mode now supports the streamable-HTTP endpoint. API keys are resolved per request (from an `Authorization: Bearer` or `api_key` header) instead of being bound to a shared connector, and the guard that previously rejected `--mode=host` with streamable-http transport is removed.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Yongteng Lei
2026-02-06 16:22:43 +08:00
committed by GitHub
parent c130ac0f88
commit 279b01a028
2 changed files with 117 additions and 56 deletions

View File

@@ -19,6 +19,10 @@ from mcp.client.streamable_http import streamablehttp_client
async def main():
try:
# To access a RAGFlow server in `host` mode, attach an `api_key` header to each request to identify the caller.
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"api_key": "ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
# Or, per OAuth 2.1 Section 5, send a Bearer token in the `Authorization` header:
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"Authorization": "Bearer ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
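
For orientation, here is a fuller host-mode client sketch built on the same calls (the API key, question, and tool arguments are placeholders; `ragflow_retrieval` matches the tool registered in `server.py`):

```python
# Hedged sketch of a complete host-mode client: authenticate with a Bearer
# token, list the exposed tools, then invoke ragflow_retrieval.
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def demo():
    headers = {"Authorization": "Bearer ragflow-<your-api-key>"}  # placeholder key
    async with streamablehttp_client("http://localhost:9382/mcp/", headers=headers) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])
            result = await session.call_tool("ragflow_retrieval", {"dataset_ids": [], "question": "What is RAGFlow?"})
            print(result.content)


if __name__ == "__main__":
    asyncio.run(demo())
```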

View File

@@ -22,18 +22,18 @@ from collections import OrderedDict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any
import click
import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum
import mcp.types as types
from mcp.server.lowlevel import Server
class LaunchMode(StrEnum):
SELF_HOST = "self-host"
@@ -68,10 +68,6 @@ class RAGFlowConnector:
self.api_url = f"{self.base_url}/api/{self.version}"
self._async_client = None
def bind_api_key(self, api_key: str):
self.api_key = api_key
self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
async def _get_client(self):
if self._async_client is None:
self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
@@ -82,16 +78,18 @@ class RAGFlowConnector:
await self._async_client.aclose()
self._async_client = None
async def _post(self, path, json=None, stream=False, files=None):
if not self.api_key:
async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
return res
async def _get(self, path, params=None):
async def _get(self, path, params=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
return res
def _is_cache_valid(self, ts):
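
Dropping `bind_api_key` makes the connector stateless with respect to tenants: the key now travels with each call, so one shared instance can serve concurrent sessions safely. A hedged sketch of what that allows (tenant keys are illustrative):

```python
# Two tenants in flight on one shared connector; each request carries its own
# key, so there is no cross-session leakage. Keys below are placeholders.
import asyncio


async def demo(connector: "RAGFlowConnector"):
    return await asyncio.gather(
        connector._get("/datasets", {"page": 1}, api_key="ragflow-tenant-a"),
        connector._get("/datasets", {"page": 1}, api_key="ragflow-tenant-b"),
    )
```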
@@ -129,8 +127,18 @@ class RAGFlowConnector:
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)
async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
async def list_datasets(
self,
*,
api_key: str,
page: int = 1,
page_size: int = 1000,
orderby: str = "create_time",
desc: bool = True,
id: str | None = None,
name: str | None = None,
):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
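
Because the new signature is keyword-only after `*`, callers must name every argument, `api_key` included. A hedged usage sketch (the key value is a placeholder):

```python
async def show_datasets(connector: "RAGFlowConnector") -> None:
    # api_key is now required and keyword-only; the value is a placeholder.
    datasets = await connector.list_datasets(api_key="ragflow-<key>", page=1, page_size=50)
    print(datasets)
```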
@@ -145,6 +153,8 @@ class RAGFlowConnector:
async def retrieval(
self,
*,
api_key: str,
dataset_ids,
document_ids=None,
question="",
@@ -162,7 +172,7 @@
# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await self.list_datasets()
dataset_list_str = await self.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
@@ -189,7 +199,7 @@
"document_ids": document_ids,
}
# Send a POST request to the backend retrieval endpoint (via the shared httpx client)
res = await self._post("/retrieval", json=data_json)
res = await self._post("/retrieval", json=data_json, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
@@ -199,7 +209,7 @@
chunks = []
# Cache document metadata and dataset information
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)
# Process chunks with enhanced field mapping including per-chunk metadata
for chunk_data in data.get("chunks", []):
@@ -228,7 +238,7 @@
raise Exception([types.TextContent(type="text", text=res.get("message"))])
async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
"""Cache document metadata for all documents in the specified datasets"""
document_cache = {}
dataset_cache = {}
@@ -238,7 +248,7 @@
dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
if not dataset_meta:
# First get dataset info for name
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
if dataset_res and dataset_res.status_code == 200:
dataset_data = dataset_res.json()
if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -255,7 +265,9 @@
doc_id_meta_list = []
docs = {}
while page:
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
if not docs_res:
break
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]:
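
The hunk only shows the top of the documents loop; a sketch of the full pagination pattern it implies (function name and termination conditions are assumptions) would be:

```python
# Assumed shape of the pagination walk: advance page by page until the backend
# returns no documents, and stop early if _get yields None (missing api_key or
# transport failure).
async def list_all_documents(connector: "RAGFlowConnector", dataset_id: str, *, api_key: str) -> list[dict]:
    page, all_docs = 1, []
    while page:
        docs_res = await connector._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
        if not docs_res:
            break
        docs_data = docs_res.json()
        docs = docs_data.get("data", {}).get("docs") if docs_data.get("code") == 0 else None
        if not docs:
            break
        all_docs.extend(docs)
        page += 1
    return all_docs
```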
@@ -335,9 +347,59 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"
def with_api_key(required=True):
def _to_text(value: Any) -> str:
if isinstance(value, bytes):
return value.decode(errors="ignore")
return str(value)
def _extract_token_from_headers(headers: Any) -> str | None:
if not headers or not hasattr(headers, "get"):
return None
auth_keys = ("authorization", "Authorization", b"authorization", b"Authorization")
for key in auth_keys:
auth = headers.get(key)
if not auth:
continue
auth_text = _to_text(auth).strip()
if auth_text.lower().startswith("bearer "):
token = auth_text[7:].strip()
if token:
return token
api_key_keys = ("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")
for key in api_key_keys:
token = headers.get(key)
if token:
token_text = _to_text(token).strip()
if token_text:
return token_text
return None
def _extract_token_from_request(request: Any) -> str | None:
if request is None:
return None
state = getattr(request, "state", None)
if state is not None:
token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
if token:
return token
token = _extract_token_from_headers(getattr(request, "headers", None))
if token and state is not None:
setattr(state, AUTH_TOKEN_STATE_KEY, token)
return token
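
Taken together, the helpers accept both decoded header mappings and raw ASGI byte pairs, and prefer a Bearer token over a bare `api_key`. A few illustrative checks (plain dicts stand in for real header objects):

```python
# Plain dicts stand in for Starlette Headers / raw ASGI header mappings.
assert _extract_token_from_headers({"authorization": "Bearer tok-123"}) == "tok-123"
assert _extract_token_from_headers({b"authorization": b"Bearer tok-123"}) == "tok-123"  # raw ASGI bytes
assert _extract_token_from_headers({"x-api-key": "tok-456"}) == "tok-456"
assert _extract_token_from_headers({"authorization": "Basic abc"}) is None  # only Bearer is honored
assert _extract_token_from_headers({"authorization": "", "api_key": "tok-789"}) == "tok-789"  # falls back to api_key
```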
def with_api_key(required: bool = True):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
@@ -347,26 +409,14 @@ def with_api_key(required=True):
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
api_key = HOST_API_KEY
if MODE == LaunchMode.HOST:
headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
token = None
# lower case here, because of Starlette conversion
auth = headers.get("authorization", "")
if auth.startswith("Bearer "):
token = auth.removeprefix("Bearer ").strip()
elif "api_key" in headers:
token = headers["api_key"]
if required and not token:
api_key = _extract_token_from_request(getattr(ctx, "request", None)) or ""
if required and not api_key:
raise ValueError("RAGFlow API key or Bearer token is required.")
connector.bind_api_key(token)
else:
connector.bind_api_key(HOST_API_KEY)
return await func(*args, connector=connector, **kwargs)
return await func(*args, connector=connector, api_key=api_key, **kwargs)
return wrapper
@@ -375,8 +425,8 @@ def with_api_key(required=True):
@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
dataset_description = await connector.list_datasets()
async def list_tools(*, connector: RAGFlowConnector, api_key: str) -> list[types.Tool]:
dataset_description = await connector.list_datasets(api_key=api_key)
return [
types.Tool(
@@ -446,7 +496,13 @@ async def list_tools(*, connector) -> list[types.Tool]:
@app.call_tool()
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
async def call_tool(
name: str,
arguments: dict,
*,
connector: RAGFlowConnector,
api_key: str,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", [])
dataset_ids = arguments.get("dataset_ids", [])
@@ -462,7 +518,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await connector.list_datasets()
dataset_list_str = await connector.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
@@ -477,6 +533,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
continue
return await connector.retrieval(
api_key=api_key,
dataset_ids=dataset_ids,
document_ids=document_ids,
question=question,
@@ -510,17 +567,13 @@ def create_starlette_app():
path = scope["path"]
if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
headers = dict(scope["headers"])
token = None
auth_header = headers.get(b"authorization")
if auth_header and auth_header.startswith(b"Bearer "):
token = auth_header.removeprefix(b"Bearer ").strip()
elif b"api_key" in headers:
token = headers[b"api_key"]
token = _extract_token_from_headers(headers)
if not token:
response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
await response(scope, receive, send)
return
scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token
await self.app(scope, receive, send)
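
The hunk above shows only the changed body; reconstructed as a complete ASGI middleware (class name illustrative, details may differ from the real one in `server.py`), the idea is to authenticate once at the edge and stash the token in `scope["state"]`, where `_extract_token_from_request` later finds it without re-parsing headers:

```python
# Illustrative reconstruction of the auth middleware; tokens stored here
# surface later via request.state in _extract_token_from_request.
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send


class AuthHeaderMiddleware:
    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http" and scope["path"].startswith(("/messages/", "/sse", "/mcp")):
            token = _extract_token_from_headers(dict(scope["headers"]))
            if not token:
                response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                await response(scope, receive, send)
                return
            scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token
        await self.app(scope, receive, send)
```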
@@ -547,9 +600,8 @@ def create_starlette_app():
# Add streamable HTTP route if enabled
streamablehttp_lifespan = None
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
from starlette.types import Receive, Scope, Send
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.types import Receive, Scope, Send
session_manager = StreamableHTTPSessionManager(
app=app,
@@ -558,9 +610,12 @@ def create_starlette_app():
stateless=True,
)
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
class StreamableHTTPEntry:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)
streamable_http_entry = StreamableHTTPEntry()
@asynccontextmanager
async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
async with session_manager.run():
@@ -570,7 +625,12 @@ def create_starlette_app():
finally:
logging.info("StreamableHTTP application shutting down...")
routes.append(Mount("/mcp", app=handle_streamable_http))
routes.extend(
[
Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
Mount("/mcp", app=streamable_http_entry),
]
)
return Starlette(
debug=True,
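
Registering both a `Route` and a `Mount` at `/mcp` covers the exact path as well as the trailing-slash and nested forms. The switch from a plain handler function to a `StreamableHTTPEntry` instance matters here: Starlette's `Route` wraps plain functions as request/response handlers, while a callable instance is passed through as a raw ASGI app. A minimal sketch of the same pattern (names are illustrative):

```python
# Route matches exactly "/mcp" (no trailing slash); Mount matches "/mcp/" plus
# anything beneath it. A class instance keeps Starlette from wrapping the
# endpoint as a request handler.
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Mount, Route
from starlette.types import Receive, Scope, Send


class Entry:
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await PlainTextResponse("ok")(scope, receive, send)


entry = Entry()
demo = Starlette(routes=[
    Route("/mcp", endpoint=entry, methods=["GET", "POST", "DELETE"]),
    Mount("/mcp", app=entry),
])
```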
@@ -631,9 +691,6 @@ def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_s
if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
raise click.UsageError("--api-key is required when --mode is 'self-host'")
if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
JSON_RESPONSE = False
@@ -690,7 +747,7 @@ if __name__ == "__main__":
--base-url=http://127.0.0.1:9380 \
--mode=self-host --api-key=ragflow-xxxxx
2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers):
2. Host mode (multi-tenant, clients must provide Authorization headers):
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
--base-url=http://127.0.0.1:9380 \
--mode=host