From c8338dec57173d32e35efba25651586635152d3f Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Tue, 27 Jan 2026 12:45:43 +0800
Subject: [PATCH] Refa: convert RAGFlow MCP server from sync to async (#12834)

### What problem does this PR solve?

Convert the RAGFlow MCP server from synchronous `requests` calls to a shared async `httpx` client, so HTTP I/O no longer blocks the event loop that serves MCP sessions. A short usage sketch and a follow-up note are included after the diff.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
---
 mcp/server/server.py | 92 ++++++++++++++++++++++----------------------
 1 file changed, 47 insertions(+), 45 deletions(-)

diff --git a/mcp/server/server.py b/mcp/server/server.py
index 8350b184b..3c83ea2b5 100644
--- a/mcp/server/server.py
+++ b/mcp/server/server.py
@@ -24,7 +24,7 @@ from contextlib import asynccontextmanager
 from functools import wraps
 
 import click
-import requests
+import httpx
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse, Response
@@ -66,19 +66,32 @@ class RAGFlowConnector:
         self.base_url = base_url
         self.version = version
         self.api_url = f"{self.base_url}/api/{self.version}"
+        self._async_client = None
 
     def bind_api_key(self, api_key: str):
         self.api_key = api_key
-        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
+        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}
 
-    def _post(self, path, json=None, stream=False, files=None):
+    async def _get_client(self):
+        if self._async_client is None:
+            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
+        return self._async_client
+
+    async def close(self):
+        if self._async_client is not None:
+            await self._async_client.aclose()
+            self._async_client = None
+
+    async def _post(self, path, json=None, stream=False, files=None):
         if not self.api_key:
             return None
-        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
+        client = await self._get_client()
+        res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header, files=files)
         return res
 
-    def _get(self, path, params=None, json=None):
-        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
+    async def _get(self, path, params=None):
+        client = await self._get_client()
+        res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
         return res
 
     def _is_cache_valid(self, ts):
@@ -116,10 +129,10 @@ class RAGFlowConnector:
             self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
             self._document_metadata_cache.move_to_end(dataset_id)
 
-    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
-        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+    async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
 
         res = res.json()
         if res.get("code") == 0:
@@ -130,7 +143,7 @@
                 return "\n".join(result_list)
         return ""
-    def retrieval(
+    async def retrieval(
         self,
         dataset_ids,
         document_ids=None,
@@ -146,15 +159,15 @@ class RAGFlowConnector:
     ):
         if document_ids is None:
            document_ids = []
-        
+
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = self.list_datasets()
+            dataset_list_str = await self.list_datasets()
             dataset_ids = []
-            
+
             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -162,7 +175,7 @@ class RAGFlowConnector:
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue
-        
+
         data_json = {
             "page": page,
             "page_size": page_size,
@@ -176,9 +189,9 @@
             "document_ids": document_ids,
         }
-        # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = self._post("/retrieval", json=data_json)
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+        # Send a POST request to the retrieval endpoint through the shared async httpx client
+        res = await self._post("/retrieval", json=data_json)
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
 
         res = res.json()
         if res.get("code") == 0:
             data = res["data"]
             chunks = []
 
             # Cache document metadata and dataset information
-            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+            document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
 
             # Process chunks with enhanced field mapping including per-chunk metadata
             for chunk_data in data.get("chunks", []):
@@ -215,7 +228,7 @@
         raise Exception([types.TextContent(type="text", text=res.get("message"))])
 
-    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+    async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
         """Cache document metadata for all documents in the specified datasets"""
         document_cache = {}
         dataset_cache = {}
@@ -225,7 +238,7 @@
             dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
             if not dataset_meta:
                 # First get dataset info for name
-                dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
                 if dataset_res and dataset_res.status_code == 200:
                     dataset_data = dataset_res.json()
                     if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -242,7 +255,7 @@
             doc_id_meta_list = []
             docs = {}
             while page:
-                docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
+                docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
                 docs_data = docs_res.json()
                 if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                     for doc in docs_data["data"]["docs"]:
@@ -317,6 +330,7 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
     try:
         yield {"ragflow_ctx": ctx}
     finally:
+        await ctx.conn.close()
         logging.info("Legacy SSE application shutting down...")
 
 
@@ -362,7 +376,7 @@ def with_api_key(required=True):
 @app.list_tools()
 @with_api_key(required=True)
 async def list_tools(*, connector) -> list[types.Tool]:
-    dataset_description = connector.list_datasets()
+    dataset_description = await connector.list_datasets()
 
     return [
        types.Tool(
@@ -372,20 +386,9 @@ async def list_tools(*, connector) -> list[types.Tool]:
                inputSchema={
                    "type": "object",
                    "properties": {
-                        "dataset_ids": {
-                            "type": "array",
-                            "items": {"type": "string"},
-                            "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."
-                        },
-                        "document_ids": {
-                            "type": "array",
-                            "items": {"type": "string"},
-                            "description": "Optional array of document IDs to search within."
-                        },
-                        "question": {
-                            "type": "string",
-                            "description": "The question or query to search for."
-                        },
+                        "dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
+                        "document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
+                        "question": {"type": "string", "description": "The question or query to search for."},
                        "page": {
                            "type": "integer",
                            "description": "Page number for pagination",
@@ -457,15 +460,14 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
    rerank_id = arguments.get("rerank_id")
    force_refresh = arguments.get("force_refresh", False)
 
-    # If no dataset_ids provided or empty list, get all available dataset IDs
    if not dataset_ids:
-        dataset_list_str = connector.list_datasets()
+        dataset_list_str = await connector.list_datasets()
        dataset_ids = []
-        
+
        # Parse the dataset list to extract IDs
        if dataset_list_str:
-            for line in dataset_list_str.strip().split('\n'):
+            for line in dataset_list_str.strip().split("\n"):
                if line.strip():
                    try:
                        dataset_info = json.loads(line.strip())
@@ -473,8 +475,8 @@
                    except (json.JSONDecodeError, KeyError):
                        # Skip malformed lines
                        continue
-    
-    return connector.retrieval(
+
+    return await connector.retrieval(
        dataset_ids=dataset_ids,
        document_ids=document_ids,
        question=question,
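
For reviewers who want to exercise the change locally, here is a minimal end-to-end sketch of the converted connector. It is illustrative only: `BASE_URL`, `API_KEY`, the import path, and the standalone `main()` are assumptions for the demo, not part of the patch.

```python
# Sketch of driving the now-async RAGFlowConnector; not part of the patch.
# Assumes a RAGFlow instance at BASE_URL and a valid API key.
import asyncio

from mcp.server.server import RAGFlowConnector  # import path may differ per checkout

BASE_URL = "http://127.0.0.1:9380"  # assumed local RAGFlow endpoint
API_KEY = "ragflow-..."  # placeholder, not a real key


async def main() -> None:
    conn = RAGFlowConnector(base_url=BASE_URL)
    conn.bind_api_key(API_KEY)
    try:
        # Both calls now await the lazily created, shared httpx.AsyncClient,
        # so they reuse one connection pool instead of opening a fresh
        # connection per request as the old module-level requests calls did.
        print(await conn.list_datasets(page_size=10))
        print(await conn.retrieval(dataset_ids=[], question="What is RAGFlow?"))
    finally:
        # Mirrors the sse_lifespan change above: release the pooled client.
        await conn.close()


asyncio.run(main())
```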
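One thing the diff makes visible: `retrieval()` and `call_tool()` now contain the same JSON-lines parsing loop for recovering dataset IDs from `list_datasets()`. A shared helper could remove that duplication in a follow-up; a sketch is below (the name `_all_dataset_ids` is hypothetical, not in this patch).

```python
import json


# Hypothetical follow-up helper, not in this patch: both retrieval() and
# call_tool() could call this instead of duplicating the parsing loop.
async def _all_dataset_ids(connector) -> list[str]:
    """Return the IDs of every dataset the connector can see.

    list_datasets() emits one JSON object per line; malformed lines are
    skipped, matching the duplicated loops in the diff above.
    """
    dataset_list_str = await connector.list_datasets()
    if not dataset_list_str:
        return []
    dataset_ids = []
    for line in dataset_list_str.strip().split("\n"):
        if not line.strip():
            continue
        try:
            dataset_ids.append(json.loads(line)["id"])
        except (json.JSONDecodeError, KeyError):
            continue  # skip malformed lines, as the original loops do
    return dataset_ids
```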