diff --git a/mcp/server/server.py b/mcp/server/server.py
index 88b548062..8d0d12c25 100644
--- a/mcp/server/server.py
+++ b/mcp/server/server.py
@@ -16,6 +16,9 @@
 import json
 import logging
+import random
+import time
+from collections import OrderedDict
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import wraps
@@ -53,6 +56,13 @@ JSON_RESPONSE = True
 
 
 class RAGFlowConnector:
+    _MAX_DATASET_CACHE = 32
+    _MAX_DOCUMENT_CACHE = 128
+    _CACHE_TTL = 300
+
+    _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
+    _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
+
     def __init__(self, base_url: str, version="v1"):
         self.base_url = base_url
         self.version = version
@@ -72,6 +82,43 @@ class RAGFlowConnector:
         res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
         return res
 
+    def _is_cache_valid(self, ts):
+        return time.time() < ts
+
+    def _get_expiry_timestamp(self):
+        offset = random.randint(-30, 30)
+        return time.time() + self._CACHE_TTL + offset
+
+    def _get_cached_dataset_metadata(self, dataset_id):
+        entry = self._dataset_metadata_cache.get(dataset_id)
+        if entry:
+            data, ts = entry
+            if self._is_cache_valid(ts):
+                self._dataset_metadata_cache.move_to_end(dataset_id)
+                return data
+        return None
+
+    def _set_cached_dataset_metadata(self, dataset_id, metadata):
+        self._dataset_metadata_cache[dataset_id] = (metadata, self._get_expiry_timestamp())
+        self._dataset_metadata_cache.move_to_end(dataset_id)
+        if len(self._dataset_metadata_cache) > self._MAX_DATASET_CACHE:
+            self._dataset_metadata_cache.popitem(last=False)
+
+    def _get_cached_document_metadata_by_dataset(self, dataset_id):
+        entry = self._document_metadata_cache.get(dataset_id)
+        if entry:
+            data_list, ts = entry
+            if self._is_cache_valid(ts):
+                self._document_metadata_cache.move_to_end(dataset_id)
+                return {doc_id: doc_meta for doc_id, doc_meta in data_list}
+        return None
+
+    def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
+        self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
+        self._document_metadata_cache.move_to_end(dataset_id)
+        if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
+            self._document_metadata_cache.popitem(last=False)
+
     def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
         res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
         if not res:
@@ -87,10 +134,38 @@ class RAGFlowConnector:
         return ""
 
     def retrieval(
-        self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
+        self,
+        dataset_ids,
+        document_ids=None,
+        question="",
+        page=1,
+        page_size=30,
+        similarity_threshold=0.2,
+        vector_similarity_weight=0.3,
+        top_k=1024,
+        rerank_id: str | None = None,
+        keyword: bool = False,
+        force_refresh: bool = False,
     ):
        if document_ids is None:
            document_ids = []
+
+        # If no dataset_ids provided or empty list, get all available dataset IDs
+        if not dataset_ids:
+            dataset_list_str = self.list_datasets()
+            dataset_ids = []
+
+            # Parse the dataset list to extract IDs
+            if dataset_list_str:
+                for line in dataset_list_str.strip().split("\n"):
+                    if line.strip():
+                        try:
+                            dataset_info = json.loads(line.strip())
+                            dataset_ids.append(dataset_info["id"])
+                        except (json.JSONDecodeError, KeyError):
+                            # Skip malformed lines
+                            continue
+
         data_json = {
             "page": page,
             "page_size": page_size,
@@ -110,12 +185,127 @@ class RAGFlowConnector:
 
         res = res.json()
         if res.get("code") == 0:
+            data = res["data"]
             chunks = []
-            for chunk_data in res["data"].get("chunks"):
-                chunks.append(json.dumps(chunk_data, ensure_ascii=False))
-            return [types.TextContent(type="text", text="\n".join(chunks))]
+
+            # Resolve document and dataset metadata (served from cache unless force_refresh is set)
+            document_cache, dataset_cache = self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
+
+            # Process chunks with enhanced field mapping, including per-chunk metadata
+            for chunk_data in data.get("chunks", []):
+                enhanced_chunk = self._map_chunk_fields(chunk_data, dataset_cache, document_cache)
+                chunks.append(enhanced_chunk)
+
+            # Build structured response (no longer need response-level document_metadata)
+            response = {
+                "chunks": chunks,
+                "pagination": {
+                    "page": data.get("page", page),
+                    "page_size": data.get("page_size", page_size),
+                    "total_chunks": data.get("total", len(chunks)),
+                    "total_pages": (data.get("total", len(chunks)) + page_size - 1) // page_size,
+                },
+                "query_info": {
+                    "question": question,
+                    "similarity_threshold": similarity_threshold,
+                    "vector_weight": vector_similarity_weight,
+                    "keyword_search": keyword,
+                    "dataset_count": len(dataset_ids),
+                },
+            }
+
+            return [types.TextContent(type="text", text=json.dumps(response, ensure_ascii=False))]
+
         raise Exception([types.TextContent(type="text", text=res.get("message"))])
 
+    def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
+        """Cache document metadata for all documents in the specified datasets"""
+        document_cache = {}
+        dataset_cache = {}
+
+        try:
+            for dataset_id in dataset_ids:
+                dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
+                if not dataset_meta:
+                    # First get dataset info for name
+                    dataset_res = self._get("/datasets", {"id": dataset_id, "page_size": 1})
+                    if dataset_res and dataset_res.status_code == 200:
+                        dataset_data = dataset_res.json()
+                        if dataset_data.get("code") == 0 and dataset_data.get("data"):
+                            dataset_info = dataset_data["data"][0]
+                            dataset_meta = {"name": dataset_info.get("name", "Unknown"), "description": dataset_info.get("description", "")}
+                            self._set_cached_dataset_metadata(dataset_id, dataset_meta)
+                if dataset_meta:
+                    dataset_cache[dataset_id] = dataset_meta
+
+                docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
+                if docs is None:
+                    docs_res = self._get(f"/datasets/{dataset_id}/documents")
+                    docs_data = docs_res.json()
+                    if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
+                        doc_id_meta_list = []
+                        docs = {}
+                        for doc in docs_data["data"]["docs"]:
+                            doc_id = doc.get("id")
+                            if not doc_id:
+                                continue
+                            doc_meta = {
+                                "document_id": doc_id,
+                                "name": doc.get("name", ""),
+                                "location": doc.get("location", ""),
+                                "type": doc.get("type", ""),
+                                "size": doc.get("size"),
+                                "chunk_count": doc.get("chunk_count"),
+                                # "chunk_method": doc.get("chunk_method", ""),
+                                "create_date": doc.get("create_date", ""),
+                                "update_date": doc.get("update_date", ""),
+                                # "process_begin_at": doc.get("process_begin_at", ""),
+                                # "process_duration": doc.get("process_duration"),
+                                # "progress": doc.get("progress"),
doc.get("progress"), + # "progress_msg": doc.get("progress_msg", ""), + # "status": doc.get("status", ""), + # "run": doc.get("run", ""), + "token_count": doc.get("token_count"), + # "source_type": doc.get("source_type", ""), + "thumbnail": doc.get("thumbnail", ""), + "dataset_id": doc.get("dataset_id", dataset_id), + "meta_fields": doc.get("meta_fields", {}), + # "parser_config": doc.get("parser_config", {}) + } + doc_id_meta_list.append((doc_id, doc_meta)) + docs[doc_id] = doc_meta + self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list) + if docs: + document_cache.update(docs) + + except Exception: + # Gracefully handle metadata cache failures + pass + + return document_cache, dataset_cache + + def _map_chunk_fields(self, chunk_data, dataset_cache, document_cache): + """Preserve all original API fields and add per-chunk document metadata""" + # Start with ALL raw data from API (preserve everything like original version) + mapped = dict(chunk_data) + + # Add dataset name enhancement + dataset_id = chunk_data.get("dataset_id") or chunk_data.get("kb_id") + if dataset_id and dataset_id in dataset_cache: + mapped["dataset_name"] = dataset_cache[dataset_id]["name"] + else: + mapped["dataset_name"] = "Unknown" + + # Add document name convenience field + mapped["document_name"] = chunk_data.get("document_keyword", "") + + # Add per-chunk document metadata + document_id = chunk_data.get("document_id") + if document_id and document_id in document_cache: + mapped["document_metadata"] = document_cache[document_id] + + return mapped + class RAGFlowCtx: def __init__(self, connector: RAGFlowConnector): @@ -195,7 +385,58 @@ async def list_tools(*, connector) -> list[types.Tool]: "items": {"type": "string"}, "description": "Optional array of document IDs to search within." }, - "question": {"type": "string", "description": "The question or query to search for."}, + "question": { + "type": "string", + "description": "The question or query to search for." + }, + "page": { + "type": "integer", + "description": "Page number for pagination", + "default": 1, + "minimum": 1, + }, + "page_size": { + "type": "integer", + "description": "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)", + "default": 10, + "minimum": 1, + "maximum": 100, + }, + "similarity_threshold": { + "type": "number", + "description": "Minimum similarity threshold for results", + "default": 0.2, + "minimum": 0.0, + "maximum": 1.0, + }, + "vector_similarity_weight": { + "type": "number", + "description": "Weight for vector similarity vs term similarity", + "default": 0.3, + "minimum": 0.0, + "maximum": 1.0, + }, + "keyword": { + "type": "boolean", + "description": "Enable keyword-based search", + "default": False, + }, + "top_k": { + "type": "integer", + "description": "Maximum results to consider before ranking", + "default": 1024, + "minimum": 1, + "maximum": 1024, + }, + "rerank_id": { + "type": "string", + "description": "Optional reranking model identifier", + }, + "force_refresh": { + "type": "boolean", + "description": "Set to true only if fresh dataset and document metadata is explicitly required. 
+                        "default": False,
+                    },
                 },
                 "required": ["question"],
             },
@@ -209,6 +450,16 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent]:
 
     if name == "ragflow_retrieval":
         document_ids = arguments.get("document_ids", [])
         dataset_ids = arguments.get("dataset_ids", [])
+        question = arguments.get("question", "")
+        page = arguments.get("page", 1)
+        page_size = arguments.get("page_size", 10)
+        similarity_threshold = arguments.get("similarity_threshold", 0.2)
+        vector_similarity_weight = arguments.get("vector_similarity_weight", 0.3)
+        keyword = arguments.get("keyword", False)
+        top_k = arguments.get("top_k", 1024)
+        rerank_id = arguments.get("rerank_id")
+        force_refresh = arguments.get("force_refresh", False)
+
         # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
@@ -229,7 +480,15 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent]:
         return connector.retrieval(
             dataset_ids=dataset_ids,
             document_ids=document_ids,
-            question=arguments["question"],
+            question=question,
+            page=page,
+            page_size=page_size,
+            similarity_threshold=similarity_threshold,
+            vector_similarity_weight=vector_similarity_weight,
+            keyword=keyword,
+            top_k=top_k,
+            rerank_id=rerank_id,
+            force_refresh=force_refresh,
         )
 
     raise ValueError(f"Tool not found: {name}")
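
Reviewer note: the caching layer added above is a size-bounded LRU with a jittered TTL, built on collections.OrderedDict. Below is a minimal self-contained sketch of that same pattern, isolated from the RAGFlow plumbing; the names (cache_set, cache_get, MAX_ENTRIES, TTL_SECONDS) are illustrative only and do not appear in the patch.

import random
import time
from collections import OrderedDict

MAX_ENTRIES = 32   # counterpart of _MAX_DATASET_CACHE
TTL_SECONDS = 300  # counterpart of _CACHE_TTL

cache: OrderedDict[str, tuple[dict, float]] = OrderedDict()

def cache_set(key: str, value: dict) -> None:
    # Jitter the expiry by +/-30s so entries written in one burst do not all expire together.
    cache[key] = (value, time.time() + TTL_SECONDS + random.randint(-30, 30))
    cache.move_to_end(key)         # mark as most recently used
    if len(cache) > MAX_ENTRIES:
        cache.popitem(last=False)  # evict the least recently used entry

def cache_get(key: str) -> dict | None:
    entry = cache.get(key)
    if entry:
        value, expiry = entry
        if time.time() < expiry:   # entry is still fresh
            cache.move_to_end(key)
            return value
    return None                    # missing or expired; caller re-fetches

cache_set("dataset-123", {"name": "demo"})
assert cache_get("dataset-123") == {"name": "demo"}

As in the patch, expired entries are not proactively purged: they linger until overwritten by a fresh fetch or pushed out by the size bound, and a stale-but-present key still counts toward the entry limit.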
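For completeness, a hypothetical invocation exercising the new tool arguments (values chosen arbitrarily; only "question" is required, everything else falls back to the defaults declared in the schema):

arguments = {
    "question": "How is the ingest pipeline configured?",
    "dataset_ids": [],       # empty: the server falls back to all available datasets
    "page": 1,
    "page_size": 20,
    "keyword": True,
    "force_refresh": False,  # leave False to reuse cached metadata
}
result = await call_tool("ragflow_retrieval", arguments, connector=connector)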