### What problem does this PR solve?

Expand the capabilities of the MCP Server. #8644

Special thanks to @Drasek: this change is largely based on his original implementation, which is neat and well structured. I essentially integrated his code into the codebase with minimal modifications. My main contribution is a proper cache layer for dataset and document metadata, using an LRU strategy with a 300 s TTL jittered by a random ± 30 s. The original code did not actually perform caching.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Caspar Armster <caspar@armster.de>
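
For reviewers, here is a minimal self-contained sketch of that caching strategy (written for illustration; the class name `LRUTTLCache` is hypothetical and not code from this diff): an `OrderedDict` doubles as the LRU map, and each entry carries a jittered expiry timestamp so cached entries do not all expire at the same moment.

```python
import random
import time
from collections import OrderedDict


class LRUTTLCache:
    """LRU cache whose entries expire after ttl ± jitter seconds."""

    def __init__(self, max_entries: int = 32, ttl: int = 300, jitter: int = 30):
        self.max_entries, self.ttl, self.jitter = max_entries, ttl, jitter
        self._store: OrderedDict = OrderedDict()  # key -> (value, expiry_ts)

    def get(self, key):
        entry = self._store.get(key)
        if entry:
            value, expiry_ts = entry
            if time.time() < expiry_ts:
                self._store.move_to_end(key)  # mark as most recently used
                return value
        return None  # missing or expired

    def set(self, key, value):
        # Jitter the TTL so entries written together don't expire in lockstep
        self._store[key] = (value, time.time() + self.ttl + random.randint(-self.jitter, self.jitter))
        self._store.move_to_end(key)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict the least recently used entry
```

`RAGFlowConnector` below implements the same idea inline, with one such cache for dataset metadata and one for per-dataset document metadata.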

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import logging
import random
import time
from collections import OrderedDict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps

import click
import requests
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum

import mcp.types as types
from mcp.server.lowlevel import Server


class LaunchMode(StrEnum):
    SELF_HOST = "self-host"
    HOST = "host"


class Transport(StrEnum):
    SSE = "sse"
    STREAMABLE_HTTP = "streamable-http"


BASE_URL = "http://127.0.0.1:9380"
HOST = "127.0.0.1"
PORT = "9382"
HOST_API_KEY = ""
MODE = ""
TRANSPORT_SSE_ENABLED = True
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
JSON_RESPONSE = True
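
# Each of these can be overridden via the environment (loaded with python-dotenv
# in main()). A hypothetical .env, with placeholder values, might look like:
#
#     RAGFLOW_MCP_BASE_URL=http://127.0.0.1:9380
#     RAGFLOW_MCP_HOST=127.0.0.1
#     RAGFLOW_MCP_PORT=9382
#     RAGFLOW_MCP_LAUNCH_MODE=self-host
#     RAGFLOW_MCP_HOST_API_KEY=ragflow-xxxxx
#     RAGFLOW_MCP_TRANSPORT_SSE_ENABLED=true
#     RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED=true
#     RAGFLOW_MCP_JSON_RESPONSE=true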


class RAGFlowConnector:
    _MAX_DATASET_CACHE = 32
    _MAX_DOCUMENT_CACHE = 128
    _CACHE_TTL = 300

    _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
    _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)

    def __init__(self, base_url: str, version="v1"):
        self.base_url = base_url
        self.version = version
        self.api_url = f"{self.base_url}/api/{self.version}"
        # Start unbound so _post/_get fail gracefully before bind_api_key is called
        self.api_key = None
        self.authorization_header = {}

    def bind_api_key(self, api_key: str):
        self.api_key = api_key
        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}

    def _post(self, path, json=None, stream=False, files=None):
        if not self.api_key:
            return None
        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
        return res

    def _get(self, path, params=None, json=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
        return res

    def _is_cache_valid(self, ts):
        return time.time() < ts

    def _get_expiry_timestamp(self):
        offset = random.randint(-30, 30)
        return time.time() + self._CACHE_TTL + offset

    def _get_cached_dataset_metadata(self, dataset_id):
        entry = self._dataset_metadata_cache.get(dataset_id)
        if entry:
            data, ts = entry
            if self._is_cache_valid(ts):
                self._dataset_metadata_cache.move_to_end(dataset_id)
                return data
        return None

    def _set_cached_dataset_metadata(self, dataset_id, metadata):
        self._dataset_metadata_cache[dataset_id] = (metadata, self._get_expiry_timestamp())
        self._dataset_metadata_cache.move_to_end(dataset_id)
        if len(self._dataset_metadata_cache) > self._MAX_DATASET_CACHE:
            self._dataset_metadata_cache.popitem(last=False)

    def _get_cached_document_metadata_by_dataset(self, dataset_id):
        entry = self._document_metadata_cache.get(dataset_id)
        if entry:
            data_list, ts = entry
            if self._is_cache_valid(ts):
                self._document_metadata_cache.move_to_end(dataset_id)
                return {doc_id: doc_meta for doc_id, doc_meta in data_list}
        return None

    def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
        self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
        self._document_metadata_cache.move_to_end(dataset_id)
        if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
            self._document_metadata_cache.popitem(last=False)

    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

        res = res.json()
        if res.get("code") == 0:
            result_list = []
            for data in res["data"]:
                d = {"description": data["description"], "id": data["id"]}
                result_list.append(json.dumps(d, ensure_ascii=False))
            return "\n".join(result_list)
        return ""

    def retrieval(
        self,
        dataset_ids,
        document_ids=None,
        question="",
        page=1,
        page_size=30,
        similarity_threshold=0.2,
        vector_similarity_weight=0.3,
        top_k=1024,
        rerank_id: str | None = None,
        keyword: bool = False,
        force_refresh: bool = False,
    ):
        if document_ids is None:
            document_ids = []

        # If dataset_ids is not provided or is an empty list, fall back to all available dataset IDs
        if not dataset_ids:
            dataset_list_str = self.list_datasets()
            dataset_ids = []

            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split("\n"):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        data_json = {
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "rerank_id": rerank_id,
            "keyword": keyword,
            "question": question,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
        }
        # Send a POST request to the backend retrieval endpoint
        res = self._post("/retrieval", json=data_json)
        if not res:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

        res = res.json()
        if res.get("code") == 0:
            data = res["data"]
            chunks = []

            # Fetch (or reuse cached) document metadata and dataset information
            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)

            # Process chunks with enhanced field mapping, including per-chunk metadata
            for chunk_data in data.get("chunks", []):
                enhanced_chunk = self._map_chunk_fields(chunk_data, dataset_cache, document_cache)
                chunks.append(enhanced_chunk)

            # Build the structured response (response-level document_metadata is no longer needed)
            response = {
                "chunks": chunks,
                "pagination": {
                    "page": data.get("page", page),
                    "page_size": data.get("page_size", page_size),
                    "total_chunks": data.get("total", len(chunks)),
                    "total_pages": (data.get("total", len(chunks)) + page_size - 1) // page_size,
                },
                "query_info": {
                    "question": question,
                    "similarity_threshold": similarity_threshold,
                    "vector_weight": vector_similarity_weight,
                    "keyword_search": keyword,
                    "dataset_count": len(dataset_ids),
                },
            }

            return [types.TextContent(type="text", text=json.dumps(response, ensure_ascii=False))]

        raise Exception([types.TextContent(type="text", text=res.get("message"))])

    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
        """Cache document metadata for all documents in the specified datasets."""
        document_cache = {}
        dataset_cache = {}

        try:
            for dataset_id in dataset_ids:
                dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
                if not dataset_meta:
                    # First get dataset info for the name
                    dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
                    if dataset_res and dataset_res.status_code == 200:
                        dataset_data = dataset_res.json()
                        if dataset_data.get("code") == 0 and dataset_data.get("data"):
                            dataset_info = dataset_data["data"][0]
                            dataset_meta = {"name": dataset_info.get("name", "Unknown"), "description": dataset_info.get("description", "")}
                            self._set_cached_dataset_metadata(dataset_id, dataset_meta)
                if dataset_meta:
                    dataset_cache[dataset_id] = dataset_meta

                docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
                if docs is None:
                    docs_res = self._get(f"/datasets/{dataset_id}/documents")
                    docs_data = docs_res.json()
                    if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                        doc_id_meta_list = []
                        docs = {}
                        for doc in docs_data["data"]["docs"]:
                            doc_id = doc.get("id")
                            if not doc_id:
                                continue
                            doc_meta = {
                                "document_id": doc_id,
                                "name": doc.get("name", ""),
                                "location": doc.get("location", ""),
                                "type": doc.get("type", ""),
                                "size": doc.get("size"),
                                "chunk_count": doc.get("chunk_count"),
                                # "chunk_method": doc.get("chunk_method", ""),
                                "create_date": doc.get("create_date", ""),
                                "update_date": doc.get("update_date", ""),
                                # "process_begin_at": doc.get("process_begin_at", ""),
                                # "process_duration": doc.get("process_duration"),
                                # "progress": doc.get("progress"),
                                # "progress_msg": doc.get("progress_msg", ""),
                                # "status": doc.get("status", ""),
                                # "run": doc.get("run", ""),
                                "token_count": doc.get("token_count"),
                                # "source_type": doc.get("source_type", ""),
                                "thumbnail": doc.get("thumbnail", ""),
                                "dataset_id": doc.get("dataset_id", dataset_id),
                                "meta_fields": doc.get("meta_fields", {}),
                                # "parser_config": doc.get("parser_config", {})
                            }
                            doc_id_meta_list.append((doc_id, doc_meta))
                            docs[doc_id] = doc_meta
                        self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
                if docs:
                    document_cache.update(docs)

        except Exception:
            # Gracefully handle metadata cache failures
            pass

        return document_cache, dataset_cache

    def _map_chunk_fields(self, chunk_data, dataset_cache, document_cache):
        """Preserve all original API fields and add per-chunk document metadata."""
        # Start with ALL raw data from the API (preserve everything, like the original version)
        mapped = dict(chunk_data)

        # Add dataset name enhancement
        dataset_id = chunk_data.get("dataset_id") or chunk_data.get("kb_id")
        if dataset_id and dataset_id in dataset_cache:
            mapped["dataset_name"] = dataset_cache[dataset_id]["name"]
        else:
            mapped["dataset_name"] = "Unknown"

        # Add document name convenience field
        mapped["document_name"] = chunk_data.get("document_keyword", "")

        # Add per-chunk document metadata
        document_id = chunk_data.get("document_id")
        if document_id and document_id in document_cache:
            mapped["document_metadata"] = document_cache[document_id]

        return mapped


class RAGFlowCtx:
    def __init__(self, connector: RAGFlowConnector):
        self.conn = connector


@asynccontextmanager
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
    ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))

    logging.info("Legacy SSE application started.")
    try:
        yield {"ragflow_ctx": ctx}
    finally:
        logging.info("Legacy SSE application shutting down...")


app = Server("ragflow-mcp-server", lifespan=sse_lifespan)


def with_api_key(required=True):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Failed to get the RAGFlow context.")

            connector = ragflow_ctx.conn

            if MODE == LaunchMode.HOST:
                headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
                token = None

                # Header names are lower-cased here because of the Starlette conversion
                auth = headers.get("authorization", "")
                if auth.startswith("Bearer "):
                    token = auth.removeprefix("Bearer ").strip()
                elif "api_key" in headers:
                    token = headers["api_key"]

                if required and not token:
                    raise ValueError("RAGFlow API key or Bearer token is required.")

                connector.bind_api_key(token)
            else:
                connector.bind_api_key(HOST_API_KEY)

            return await func(*args, connector=connector, **kwargs)

        return wrapper

    return decorator


@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
    dataset_description = connector.list_datasets()

    return [
        types.Tool(
            name="ragflow_retrieval",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:"
            + dataset_description,
            inputSchema={
                "type": "object",
                "properties": {
                    "dataset_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched.",
                    },
                    "document_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional array of document IDs to search within.",
                    },
                    "question": {
                        "type": "string",
                        "description": "The question or query to search for.",
                    },
                    "page": {
                        "type": "integer",
                        "description": "Page number for pagination",
                        "default": 1,
                        "minimum": 1,
                    },
                    "page_size": {
                        "type": "integer",
                        "description": "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)",
                        "default": 10,
                        "minimum": 1,
                        "maximum": 100,
                    },
                    "similarity_threshold": {
                        "type": "number",
                        "description": "Minimum similarity threshold for results",
                        "default": 0.2,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "vector_similarity_weight": {
                        "type": "number",
                        "description": "Weight for vector similarity vs term similarity",
                        "default": 0.3,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "keyword": {
                        "type": "boolean",
                        "description": "Enable keyword-based search",
                        "default": False,
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "Maximum results to consider before ranking",
                        "default": 1024,
                        "minimum": 1,
                        "maximum": 1024,
                    },
                    "rerank_id": {
                        "type": "string",
                        "description": "Optional reranking model identifier",
                    },
                    "force_refresh": {
                        "type": "boolean",
                        "description": "Set to true only if fresh dataset and document metadata is explicitly required. Otherwise, cached metadata is used (default: false).",
                        "default": False,
                    },
                },
                "required": ["question"],
            },
        ),
    ]


@app.call_tool()
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    if name == "ragflow_retrieval":
        document_ids = arguments.get("document_ids", [])
        dataset_ids = arguments.get("dataset_ids", [])
        question = arguments.get("question", "")
        page = arguments.get("page", 1)
        page_size = arguments.get("page_size", 10)
        similarity_threshold = arguments.get("similarity_threshold", 0.2)
        vector_similarity_weight = arguments.get("vector_similarity_weight", 0.3)
        keyword = arguments.get("keyword", False)
        top_k = arguments.get("top_k", 1024)
        rerank_id = arguments.get("rerank_id")
        force_refresh = arguments.get("force_refresh", False)

        # If no dataset_ids provided or empty list, get all available dataset IDs
        if not dataset_ids:
            dataset_list_str = connector.list_datasets()
            dataset_ids = []

            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split("\n"):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        return connector.retrieval(
            dataset_ids=dataset_ids,
            document_ids=document_ids,
            question=question,
            page=page,
            page_size=page_size,
            similarity_threshold=similarity_threshold,
            vector_similarity_weight=vector_similarity_weight,
            keyword=keyword,
            top_k=top_k,
            rerank_id=rerank_id,
            force_refresh=force_refresh,
        )
    raise ValueError(f"Tool not found: {name}")


def create_starlette_app():
    routes = []
    middleware = None
    if MODE == LaunchMode.HOST:
        from starlette.types import ASGIApp, Receive, Scope, Send

        class AuthMiddleware:
            def __init__(self, app: ASGIApp):
                self.app = app

            async def __call__(self, scope: Scope, receive: Receive, send: Send):
                if scope["type"] != "http":
                    await self.app(scope, receive, send)
                    return

                path = scope["path"]
                if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
                    headers = dict(scope["headers"])
                    token = None
                    auth_header = headers.get(b"authorization")
                    if auth_header and auth_header.startswith(b"Bearer "):
                        token = auth_header.removeprefix(b"Bearer ").strip()
                    elif b"api_key" in headers:
                        token = headers[b"api_key"]

                    if not token:
                        response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                        await response(scope, receive, send)
                        return

                await self.app(scope, receive, send)

        middleware = [Middleware(AuthMiddleware)]

    # Add SSE routes if enabled
    if TRANSPORT_SSE_ENABLED:
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
            return Response()

        routes.extend(
            [
                Route("/sse", endpoint=handle_sse, methods=["GET"]),
                Mount("/messages/", app=sse.handle_post_message),
            ]
        )

    # Add streamable HTTP route if enabled
    streamablehttp_lifespan = None
    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        from starlette.types import Receive, Scope, Send

        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

        session_manager = StreamableHTTPSessionManager(
            app=app,
            event_store=None,
            json_response=JSON_RESPONSE,
            stateless=True,
        )

        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
            await session_manager.handle_request(scope, receive, send)

        @asynccontextmanager
        async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
                try:
                    yield
                finally:
                    logging.info("StreamableHTTP application shutting down...")

        routes.append(Mount("/mcp", app=handle_streamable_http))

    return Starlette(
        debug=True,
        routes=routes,
        middleware=middleware,
        lifespan=streamablehttp_lifespan,
    )


@click.command()
@click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
@click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
@click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
@click.option(
    "--mode",
    type=click.Choice(["self-host", "host"]),
    default="self-host",
    help=("Launch mode:\n  self-host: run MCP for a single tenant (requires --api-key)\n  host: multi-tenant mode, users must provide Authorization headers"),
)
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
@click.option(
    "--transport-sse-enabled/--no-transport-sse-enabled",
    default=True,
    help="Enable or disable legacy SSE transport mode (default: enabled)",
)
@click.option(
    "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
    default=True,
    help="Enable or disable streamable-http transport mode (default: enabled)",
)
@click.option(
    "--json-response/--no-json-response",
    default=True,
    help="Enable or disable JSON response mode for streamable-http (default: enabled)",
)
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
    import os

    import uvicorn
    from dotenv import load_dotenv

    load_dotenv()

    def parse_bool_flag(key: str, default: bool) -> bool:
        val = os.environ.get(key, str(default))
        return str(val).strip().lower() in ("1", "true", "yes", "on")

    global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
    MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
    HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
    TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
    TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
    JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)

    if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
        raise click.UsageError("--api-key is required when --mode is 'self-host'")

    if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
        raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")

    if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
        JSON_RESPONSE = False

    print(
        r"""
 __  __  ____ ____    ____  _____ ______     _______ ____
|  \/  |/ ___|  _ \  / ___|| ____|  _ \ \   / / ____|  _ \
| |\/| | |   | |_) | \___ \|  _| | |_) \ \ / /|  _| | |_) |
| |  | | |___|  __/   ___) | |___|  _ <  \ V / | |___|  _ <
|_|  |_|\____|_|     |____/|_____|_| \_\  \_/  |_____|_| \_\
""",
        flush=True,
    )
    print(f"MCP launch mode: {MODE}", flush=True)
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)

    if not any([TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED]):
        print("At least one transport must be enabled; enabling streamable-http automatically.", flush=True)
        TRANSPORT_STREAMABLE_HTTP_ENABLED = True

    if TRANSPORT_SSE_ENABLED:
        print("SSE transport enabled: yes", flush=True)
        print("SSE endpoint available at /sse", flush=True)
    else:
        print("SSE transport enabled: no", flush=True)

    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        print("Streamable HTTP transport enabled: yes", flush=True)
        print("Streamable HTTP endpoint available at /mcp", flush=True)
        if JSON_RESPONSE:
            print("Streamable HTTP mode: JSON response enabled", flush=True)
        else:
            print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
    else:
        print("Streamable HTTP transport enabled: no", flush=True)
        if JSON_RESPONSE:
            print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)

    uvicorn.run(
        create_starlette_app(),
        host=HOST,
        port=int(PORT),
    )


if __name__ == "__main__":
    """
    Launch examples:

    1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
       uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
           --base-url=http://127.0.0.1:9380 \
           --mode=self-host --api-key=ragflow-xxxxx

    2. Host mode (multi-tenant, legacy SSE only; clients must provide Authorization headers):
       uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
           --base-url=http://127.0.0.1:9380 \
           --mode=host --no-transport-streamable-http-enabled

    3. Disable legacy SSE (only streamable HTTP will be active):
       uv run mcp/server/server.py --no-transport-sse-enabled \
           --mode=self-host --api-key=ragflow-xxxxx

    4. Disable streamable HTTP (only legacy SSE will be active):
       uv run mcp/server/server.py --no-transport-streamable-http-enabled \
           --mode=self-host --api-key=ragflow-xxxxx

    5. Use streamable HTTP with SSE-style events (disable JSON response):
       uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
           --mode=self-host --api-key=ragflow-xxxxx

    6. Disable both transports (for testing; streamable HTTP is re-enabled automatically):
       uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
           --mode=self-host --api-key=ragflow-xxxxx
    """
    main()