Refa: convert RAGFlow MCP server from sync to async (#12834)

### What problem does this PR solve?

Convert RAGFlow MCP server from sync to async.
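At a high level, the change swaps per-call `requests` invocations for a single, lazily created `httpx.AsyncClient` that is shared across requests and explicitly closed when the server's lifespan ends. A minimal sketch of that pattern, assuming a simplified connector: the method names mirror the diff, but the `__init__` signature, base URL, and API key below are illustrative placeholders rather than the exact production code.

```python
import asyncio

import httpx


class RAGFlowConnector:
    """Sketch of the async connector pattern: one shared AsyncClient, created lazily."""

    def __init__(self, base_url: str, version: str = "v1"):
        self.api_url = f"{base_url}/api/{version}"
        self.api_key = None
        self.authorization_header = {}
        self._async_client = None  # created on first request, then reused

    def bind_api_key(self, api_key: str):
        self.api_key = api_key
        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}

    async def _get_client(self) -> httpx.AsyncClient:
        # Lazy initialization keeps __init__ synchronous while all I/O stays async.
        if self._async_client is None:
            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._async_client

    async def _get(self, path: str, params: dict | None = None) -> httpx.Response:
        client = await self._get_client()
        return await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)

    async def close(self):
        # Awaited from the server's lifespan teardown so pooled connections are released.
        if self._async_client is not None:
            await self._async_client.aclose()
            self._async_client = None


async def main():
    # Hypothetical usage: the URL and key are placeholders, not real credentials.
    conn = RAGFlowConnector("http://127.0.0.1:9380")
    conn.bind_api_key("ragflow-xxxxxxxx")
    try:
        res = await conn._get("/datasets", {"page": 1, "page_size": 10})
        print(res.status_code)
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
```

Reusing one `AsyncClient` provides connection pooling across tool calls and keeps the MCP event loop unblocked, which is the main source of the performance improvement claimed below; the explicit `close()` mirrors the `await ctx.conn.close()` added to the SSE lifespan in the diff.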

### Type of change

- [x] Refactoring
- [x] Performance Improvement
Author: Yongteng Lei
Date: 2026-01-27 12:45:43 +08:00 (committed by GitHub)
Parent: f096917eeb
Commit: c8338dec57


@@ -24,7 +24,7 @@ from contextlib import asynccontextmanager
 from functools import wraps
 import click
-import requests
+import httpx
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse, Response
@@ -66,19 +66,32 @@ class RAGFlowConnector:
         self.base_url = base_url
         self.version = version
         self.api_url = f"{self.base_url}/api/{self.version}"
+        self._async_client = None
     def bind_api_key(self, api_key: str):
         self.api_key = api_key
-        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
+        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
-    def _post(self, path, json=None, stream=False, files=None):
+    async def _get_client(self):
+        if self._async_client is None:
+            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
+        return self._async_client
+    async def close(self):
+        if self._async_client is not None:
+            await self._async_client.aclose()
+            self._async_client = None
+    async def _post(self, path, json=None, stream=False, files=None):
         if not self.api_key:
             return None
-        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
+        client = await self._get_client()
+        res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
         return res
-    def _get(self, path, params=None, json=None):
-        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
+    async def _get(self, path, params=None):
+        client = await self._get_client()
+        res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
         return res
     def _is_cache_valid(self, ts):
@@ -116,10 +129,10 @@ class RAGFlowConnector:
         self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
         self._document_metadata_cache.move_to_end(dataset_id)
-    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
-        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+    async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
         res = res.json()
         if res.get("code") == 0:
@@ -130,7 +143,7 @@ class RAGFlowConnector:
             return "\n".join(result_list)
         return ""
-    def retrieval(
+    async def retrieval(
         self,
         dataset_ids,
         document_ids=None,
@@ -146,15 +159,15 @@
     ):
         if document_ids is None:
             document_ids = []
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = self.list_datasets()
+            dataset_list_str = await self.list_datasets()
             dataset_ids = []
             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -162,7 +175,7 @@
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue
         data_json = {
             "page": page,
             "page_size": page_size,
@@ -176,9 +189,9 @@
             "document_ids": document_ids,
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = self._post("/retrieval", json=data_json)
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+        res = await self._post("/retrieval", json=data_json)
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
         res = res.json()
         if res.get("code") == 0:
@@ -186,7 +199,7 @@
             chunks = []
             # Cache document metadata and dataset information
-            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+            document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
             # Process chunks with enhanced field mapping including per-chunk metadata
             for chunk_data in data.get("chunks", []):
@@ -215,7 +228,7 @@
         raise Exception([types.TextContent(type="text", text=res.get("message"))])
-    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+    async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
         """Cache document metadata for all documents in the specified datasets"""
         document_cache = {}
         dataset_cache = {}
@@ -225,7 +238,7 @@
             dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
             if not dataset_meta:
                 # First get dataset info for name
-                dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
                 if dataset_res and dataset_res.status_code == 200:
                     dataset_data = dataset_res.json()
                     if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -242,7 +255,7 @@
             doc_id_meta_list = []
             docs = {}
             while page:
-                docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
+                docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
                 docs_data = docs_res.json()
                 if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                     for doc in docs_data["data"]["docs"]:
@@ -317,6 +330,7 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
     try:
         yield {"ragflow_ctx": ctx}
     finally:
+        await ctx.conn.close()
         logging.info("Legacy SSE application shutting down...")
@@ -362,7 +376,7 @@ def with_api_key(required=True):
 @app.list_tools()
 @with_api_key(required=True)
 async def list_tools(*, connector) -> list[types.Tool]:
-    dataset_description = connector.list_datasets()
+    dataset_description = await connector.list_datasets()
     return [
         types.Tool(
@@ -372,20 +386,9 @@ async def list_tools(*, connector) -> list[types.Tool]:
             inputSchema={
                 "type": "object",
                 "properties": {
-                    "dataset_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."
-                    },
-                    "document_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of document IDs to search within."
-                    },
-                    "question": {
-                        "type": "string",
-                        "description": "The question or query to search for."
-                    },
+                    "dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
+                    "document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
+                    "question": {"type": "string", "description": "The question or query to search for."},
                     "page": {
                         "type": "integer",
                         "description": "Page number for pagination",
@@ -457,15 +460,14 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
         rerank_id = arguments.get("rerank_id")
         force_refresh = arguments.get("force_refresh", False)
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = connector.list_datasets()
+            dataset_list_str = await connector.list_datasets()
             dataset_ids = []
             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -473,8 +475,8 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue
-        return connector.retrieval(
+        return await connector.retrieval(
             dataset_ids=dataset_ids,
             document_ids=document_ids,
             question=question,