Refa: convert RAGFlow MCP server from sync to async (#12834)

### What problem does this PR solve?

Convert the RAGFlow MCP server from sync to async: blocking `requests` calls are replaced with a shared, lazily created `httpx.AsyncClient`, and the connector's HTTP helpers become awaitable coroutines.
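In outline, the change follows the pattern below. This is a minimal sketch rather than the actual `RAGFlowConnector`: the class and method names are illustrative, and the real connector adds API-key binding, metadata caching, and error mapping on top.

```python
import httpx


class AsyncAPIClient:
    """Illustrative async connector: one lazily created httpx.AsyncClient shared across calls."""

    def __init__(self, base_url: str, api_key: str):
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {api_key}"}
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        # Create the client on first use so its connection pool is reused by every request.
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._client

    async def get(self, path: str, params: dict | None = None) -> httpx.Response:
        client = await self._get_client()
        return await client.get(self.base_url + path, params=params, headers=self.headers)

    async def post(self, path: str, json: dict | None = None) -> httpx.Response:
        client = await self._get_client()
        return await client.post(self.base_url + path, json=json, headers=self.headers)

    async def close(self) -> None:
        # Release pooled connections on shutdown.
        if self._client is not None:
            await self._client.aclose()
            self._client = None
```

Every caller of the connector's helpers (`_get`/`_post`, `list_datasets`, `retrieval`, and the metadata-cache lookup) is updated to `await` them accordingly.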

### Type of change

- [x] Refactoring
- [x] Performance Improvement
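The expected performance improvement comes from not blocking the event loop on HTTP round trips: with `httpx.AsyncClient`, concurrent tool calls can overlap their network I/O instead of serializing behind synchronous `requests` calls. Shutdown cleanup moves into the SSE lifespan, roughly as sketched below (a simplified, hypothetical shape, not the exact `sse_lifespan` body):

```python
from contextlib import asynccontextmanager


@asynccontextmanager
async def sse_lifespan_sketch(connector):
    # Yield the connector to request handlers, then make sure its pooled
    # httpx client is released when the server shuts down.
    try:
        yield {"ragflow_ctx": connector}
    finally:
        await connector.close()
```
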
Author: Yongteng Lei
Date: 2026-01-27 12:45:43 +08:00
Committed by: GitHub
Parent: f096917eeb
Commit: c8338dec57


@@ -24,7 +24,7 @@ from contextlib import asynccontextmanager
 from functools import wraps
 import click
-import requests
+import httpx
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse, Response
@@ -66,19 +66,32 @@ class RAGFlowConnector:
         self.base_url = base_url
         self.version = version
         self.api_url = f"{self.base_url}/api/{self.version}"
+        self._async_client = None
 
     def bind_api_key(self, api_key: str):
         self.api_key = api_key
-        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
+        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
 
-    def _post(self, path, json=None, stream=False, files=None):
+    async def _get_client(self):
+        if self._async_client is None:
+            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
+        return self._async_client
+
+    async def close(self):
+        if self._async_client is not None:
+            await self._async_client.aclose()
+            self._async_client = None
+
+    async def _post(self, path, json=None, stream=False, files=None):
         if not self.api_key:
             return None
-        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
+        client = await self._get_client()
+        res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
         return res
 
-    def _get(self, path, params=None, json=None):
-        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
+    async def _get(self, path, params=None):
+        client = await self._get_client()
+        res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
         return res
 
     def _is_cache_valid(self, ts):
@@ -116,10 +129,10 @@ class RAGFlowConnector:
         self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
         self._document_metadata_cache.move_to_end(dataset_id)
 
-    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
-        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+    async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
 
         res = res.json()
         if res.get("code") == 0:
@@ -130,7 +143,7 @@ class RAGFlowConnector:
             return "\n".join(result_list)
         return ""
 
-    def retrieval(
+    async def retrieval(
         self,
         dataset_ids,
         document_ids=None,
@@ -146,15 +159,15 @@ class RAGFlowConnector:
     ):
         if document_ids is None:
             document_ids = []
 
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = self.list_datasets()
+            dataset_list_str = await self.list_datasets()
             dataset_ids = []
 
             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -162,7 +175,7 @@ class RAGFlowConnector:
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue
 
         data_json = {
             "page": page,
             "page_size": page_size,
@@ -176,9 +189,9 @@ class RAGFlowConnector:
             "document_ids": document_ids,
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = self._post("/retrieval", json=data_json)
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+        res = await self._post("/retrieval", json=data_json)
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
 
         res = res.json()
         if res.get("code") == 0:
@@ -186,7 +199,7 @@ class RAGFlowConnector:
             chunks = []
 
             # Cache document metadata and dataset information
-            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+            document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
 
             # Process chunks with enhanced field mapping including per-chunk metadata
             for chunk_data in data.get("chunks", []):
@@ -215,7 +228,7 @@ class RAGFlowConnector:
 
         raise Exception([types.TextContent(type="text", text=res.get("message"))])
 
-    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+    async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
         """Cache document metadata for all documents in the specified datasets"""
         document_cache = {}
         dataset_cache = {}
@@ -225,7 +238,7 @@ class RAGFlowConnector:
             dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
             if not dataset_meta:
                 # First get dataset info for name
-                dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
                 if dataset_res and dataset_res.status_code == 200:
                     dataset_data = dataset_res.json()
                     if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -242,7 +255,7 @@ class RAGFlowConnector:
             doc_id_meta_list = []
             docs = {}
             while page:
-                docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
+                docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
                 docs_data = docs_res.json()
                 if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                     for doc in docs_data["data"]["docs"]:
@@ -317,6 +330,7 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
     try:
         yield {"ragflow_ctx": ctx}
     finally:
+        await ctx.conn.close()
         logging.info("Legacy SSE application shutting down...")
@@ -362,7 +376,7 @@ def with_api_key(required=True):
 @app.list_tools()
 @with_api_key(required=True)
 async def list_tools(*, connector) -> list[types.Tool]:
-    dataset_description = connector.list_datasets()
+    dataset_description = await connector.list_datasets()
 
     return [
         types.Tool(
@@ -372,20 +386,9 @@ async def list_tools(*, connector) -> list[types.Tool]:
             inputSchema={
                 "type": "object",
                 "properties": {
-                    "dataset_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."
-                    },
-                    "document_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of document IDs to search within."
-                    },
-                    "question": {
-                        "type": "string",
-                        "description": "The question or query to search for."
-                    },
+                    "dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
+                    "document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
+                    "question": {"type": "string", "description": "The question or query to search for."},
                     "page": {
                         "type": "integer",
                         "description": "Page number for pagination",
@@ -457,15 +460,14 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
         rerank_id = arguments.get("rerank_id")
         force_refresh = arguments.get("force_refresh", False)
 
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = connector.list_datasets()
+            dataset_list_str = await connector.list_datasets()
             dataset_ids = []
 
             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -473,8 +475,8 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue
 
-        return connector.retrieval(
+        return await connector.retrieval(
            dataset_ids=dataset_ids,
            document_ids=document_ids,
            question=question,