Mirror of https://github.com/infiniflow/ragflow.git
Refa: convert RAGFlow MCP server from sync to async (#12834)
### What problem does this PR solve?

Convert RAGFlow MCP server from sync to async.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
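The heart of the change is replacing blocking `requests` calls with a single, lazily created `httpx.AsyncClient` that is reused across requests and explicitly closed on shutdown. A minimal standalone sketch of that pattern, assuming illustrative class and method names (only the lazy-client/close shape is taken from this commit):

```python
import asyncio

import httpx


class AsyncHTTPExample:
    """Illustrative sketch of the lazy shared-AsyncClient pattern used in this commit."""

    def __init__(self, base_url: str):
        self.base_url = base_url
        self._client = None  # created on first use, keeping __init__ synchronous

    async def _get_client(self) -> httpx.AsyncClient:
        # Lazy initialization: the client (and its connection pool) is built once.
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._client

    async def get_json(self, path: str) -> dict:
        client = await self._get_client()
        res = await client.get(self.base_url + path)
        return res.json()

    async def close(self) -> None:
        # Release the connection pool when the application shuts down.
        if self._client is not None:
            await self._client.aclose()
            self._client = None


async def main() -> None:
    api = AsyncHTTPExample("https://httpbin.org")  # placeholder endpoint for illustration
    try:
        print(await api.get_json("/get"))
    finally:
        await api.close()


if __name__ == "__main__":
    asyncio.run(main())
```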
@@ -24,7 +24,7 @@ from contextlib import asynccontextmanager
 from functools import wraps

 import click
-import requests
+import httpx
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse, Response
@@ -66,19 +66,32 @@ class RAGFlowConnector:
         self.base_url = base_url
         self.version = version
         self.api_url = f"{self.base_url}/api/{self.version}"
+        self._async_client = None

     def bind_api_key(self, api_key: str):
         self.api_key = api_key
-        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
+        self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}

-    def _post(self, path, json=None, stream=False, files=None):
+    async def _get_client(self):
+        if self._async_client is None:
+            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
+        return self._async_client
+
+    async def close(self):
+        if self._async_client is not None:
+            await self._async_client.aclose()
+            self._async_client = None
+
+    async def _post(self, path, json=None, stream=False, files=None):
         if not self.api_key:
             return None
-        res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
+        client = await self._get_client()
+        res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
         return res

-    def _get(self, path, params=None, json=None):
-        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
+    async def _get(self, path, params=None):
+        client = await self._get_client()
+        res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
         return res

     def _is_cache_valid(self, ts):
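With the HTTP layer now async, every caller must await these methods and make sure `close()` runs. A hedged usage sketch — the `base_url`/`version` constructor arguments and the import path are assumptions inferred from the attribute assignments above, not shown in this diff:

```python
import asyncio

# Assumed import path, for illustration only.
# from mcp.server.server import RAGFlowConnector


async def demo(connector_cls) -> None:
    # Constructor arguments are assumed from self.base_url / self.version above.
    conn = connector_cls(base_url="http://127.0.0.1:9380", version="v1")
    conn.bind_api_key("YOUR_RAGFLOW_API_KEY")  # builds the Bearer Authorization header
    try:
        # Connector methods are now coroutines and must be awaited.
        print(await conn.list_datasets(page=1, page_size=10))
    finally:
        # Without this, the pooled httpx.AsyncClient would keep connections open.
        await conn.close()


# asyncio.run(demo(RAGFlowConnector))
```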
@@ -116,10 +129,10 @@ class RAGFlowConnector:
         self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
         self._document_metadata_cache.move_to_end(dataset_id)

-    def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
-        res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+    async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

         res = res.json()
         if res.get("code") == 0:
@@ -130,7 +143,7 @@
             return "\n".join(result_list)
         return ""

-    def retrieval(
+    async def retrieval(
         self,
         dataset_ids,
         document_ids=None,
@@ -146,15 +159,15 @@
     ):
         if document_ids is None:
            document_ids = []

         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = self.list_datasets()
+            dataset_list_str = await self.list_datasets()
             dataset_ids = []

             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
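`list_datasets()` returns one JSON object per line, so the fallback path above collects IDs by parsing each line individually and skipping anything malformed. A small self-contained illustration of that parsing step — the sample payload is made up, and the assumption that the hidden line between the hunks appends `dataset_info["id"]` is inferred from the "extract IDs" comment and the `KeyError` handler:

```python
import json

# Made-up example of a newline-delimited dataset listing.
dataset_list_str = "\n".join(
    [
        '{"id": "ds-111", "name": "manuals"}',
        "not json at all",  # malformed line, should be skipped
        '{"id": "ds-222", "name": "tickets"}',
    ]
)

dataset_ids = []
if dataset_list_str:
    for line in dataset_list_str.strip().split("\n"):
        if line.strip():
            try:
                dataset_info = json.loads(line.strip())
                dataset_ids.append(dataset_info["id"])  # assumed key, see note above
            except (json.JSONDecodeError, KeyError):
                # Skip malformed lines, same as the connector does.
                continue

print(dataset_ids)  # ['ds-111', 'ds-222']
```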
@@ -162,7 +175,7 @@
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue

         data_json = {
             "page": page,
             "page_size": page_size,
@@ -176,9 +189,9 @@
             "document_ids": document_ids,
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = self._post("/retrieval", json=data_json)
-        if not res:
-            raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
+        res = await self._post("/retrieval", json=data_json)
+        if not res or res.status_code != 200:
+            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

         res = res.json()
         if res.get("code") == 0:
@@ -186,7 +199,7 @@
             chunks = []

             # Cache document metadata and dataset information
-            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+            document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)

             # Process chunks with enhanced field mapping including per-chunk metadata
             for chunk_data in data.get("chunks", []):
@@ -215,7 +228,7 @@

         raise Exception([types.TextContent(type="text", text=res.get("message"))])

-    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+    async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
         """Cache document metadata for all documents in the specified datasets"""
         document_cache = {}
         dataset_cache = {}
@@ -225,7 +238,7 @@
             dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
             if not dataset_meta:
                 # First get dataset info for name
-                dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
                 if dataset_res and dataset_res.status_code == 200:
                     dataset_data = dataset_res.json()
                     if dataset_data.get("code") == 0 and dataset_data.get("data"):
@@ -242,7 +255,7 @@
                 doc_id_meta_list = []
                 docs = {}
                 while page:
-                    docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
+                    docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
                     docs_data = docs_res.json()
                     if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                         for doc in docs_data["data"]["docs"]:
@@ -317,6 +330,7 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
     try:
         yield {"ragflow_ctx": ctx}
     finally:
+        await ctx.conn.close()
         logging.info("Legacy SSE application shutting down...")


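The lifespan hook is also where the new `close()` gets called, so the pooled client is released when the SSE app shuts down. A minimal sketch of that resource-owning lifespan shape, with placeholder names (the real hook yields `{"ragflow_ctx": ctx}` and calls `ctx.conn.close()` as shown above):

```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import httpx


class Context:
    """Placeholder for the server context that owns the shared async client."""

    def __init__(self) -> None:
        self.client = httpx.AsyncClient()

    async def close(self) -> None:
        await self.client.aclose()


@asynccontextmanager
async def lifespan(app) -> AsyncIterator[dict]:
    ctx = Context()
    try:
        # Whatever is yielded here is available to handlers for the app's lifetime.
        yield {"ctx": ctx}
    finally:
        # Guaranteed cleanup on shutdown, mirroring `await ctx.conn.close()` above.
        await ctx.close()
```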
@@ -362,7 +376,7 @@ def with_api_key(required=True):
 @app.list_tools()
 @with_api_key(required=True)
 async def list_tools(*, connector) -> list[types.Tool]:
-    dataset_description = connector.list_datasets()
+    dataset_description = await connector.list_datasets()

     return [
         types.Tool(
@@ -372,20 +386,9 @@ async def list_tools(*, connector) -> list[types.Tool]:
             inputSchema={
                 "type": "object",
                 "properties": {
-                    "dataset_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."
-                    },
-                    "document_ids": {
-                        "type": "array",
-                        "items": {"type": "string"},
-                        "description": "Optional array of document IDs to search within."
-                    },
-                    "question": {
-                        "type": "string",
-                        "description": "The question or query to search for."
-                    },
+                    "dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
+                    "document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
+                    "question": {"type": "string", "description": "The question or query to search for."},
                     "page": {
                         "type": "integer",
                         "description": "Page number for pagination",
@@ -457,15 +460,14 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
         rerank_id = arguments.get("rerank_id")
         force_refresh = arguments.get("force_refresh", False)

-
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = connector.list_datasets()
+            dataset_list_str = await connector.list_datasets()
             dataset_ids = []

             # Parse the dataset list to extract IDs
             if dataset_list_str:
-                for line in dataset_list_str.strip().split('\n'):
+                for line in dataset_list_str.strip().split("\n"):
                     if line.strip():
                         try:
                             dataset_info = json.loads(line.strip())
@@ -473,8 +475,8 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
                         except (json.JSONDecodeError, KeyError):
                             # Skip malformed lines
                             continue

-        return connector.retrieval(
+        return await connector.retrieval(
             dataset_ids=dataset_ids,
             document_ids=document_ids,
             question=question,
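The `await` in `return await connector.retrieval(...)` is the important part of this last hunk: now that `retrieval` is a coroutine function, calling it without awaiting would hand back an unawaited coroutine object instead of the list of `TextContent` results. A tiny illustration of the difference (names here are purely illustrative):

```python
import asyncio


async def retrieval() -> list[str]:
    return ["chunk-1", "chunk-2"]


async def without_await():
    return retrieval()  # returns a coroutine object, not the results


async def with_await():
    return await retrieval()  # returns the actual list


async def main() -> None:
    print(await without_await())  # <coroutine object retrieval at 0x...> plus a RuntimeWarning
    print(await with_await())     # ['chunk-1', 'chunk-2']


if __name__ == "__main__":
    asyncio.run(main())
```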