From c642dbefca3b2422c3a968e4d816deee8b7505d1 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 15 Jul 2025 09:36:45 +0800 Subject: [PATCH] Perf: Enhance timeout handling. (#8826) ### What problem does this PR solve? ### Type of change - [x] Performance Improvement --- api/apps/mcp_server_app.py | 20 +++- api/utils/api_utils.py | 108 ++++++++++++++++++ api/utils/mcp_server.py | 34 ------ .../general/community_reports_extractor.py | 5 +- graphrag/general/extractor.py | 2 + graphrag/general/index.py | 6 + graphrag/utils.py | 6 + rag/raptor.py | 3 + rag/svr/task_executor.py | 5 + rag/utils/redis_conn.py | 103 +++++++++-------- 10 files changed, 207 insertions(+), 85 deletions(-) delete mode 100644 api/utils/mcp_server.py diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index faa53328e..ef1fba79c 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -1,3 +1,18 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from flask import Response, request from flask_login import current_user, login_required @@ -6,9 +21,10 @@ from api.db.db_models import MCPServer from api.db.services.mcp_server_service import MCPServerService from api.db.services.user_service import TenantService from api.settings import RetCode + from api.utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request -from api.utils.mcp_server import get_mcp_tools +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ + get_mcp_tools from api.utils.web_utils import get_float, safe_json_parse from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 71e354a34..546cb2dfb 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -13,19 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import functools import json import logging +import queue import random +import threading import time from base64 import b64encode from copy import deepcopy from functools import wraps from hmac import HMAC from io import BytesIO +from typing import Any, Optional, Union, Callable, Coroutine, Type from urllib.parse import quote, urlencode from uuid import uuid1 +import trio + +from api.db.db_models import MCPServer +from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions + + import requests from flask import ( Response, @@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict: transformed_data[mapped_key] = value return transformed_data + + +def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]: + results = {} + tool_call_sessions = [] + try: + for mcp_server in mcp_servers: + server_key = mcp_server.id + + cached_tools = mcp_server.variables.get("tools", {}) + + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + tool_call_sessions.append(tool_call_session) + + try: + tools = tool_call_session.get_tools(timeout) + except Exception: + tools = [] + + results[server_key] = [] + for tool in tools: + tool_dict = tool.model_dump() + cached_tool = cached_tools.get(tool_dict["name"], {}) + + tool_dict["enabled"] = cached_tool.get("enabled", True) + results[server_key].append(tool_dict) + + # PERF: blocking call to close sessions — consider moving to background thread or task queue + close_multiple_mcp_toolcall_sessions(tool_call_sessions) + return results, "" + except Exception as e: + return {}, str(e) + + +TimeoutException = Union[Type[BaseException], BaseException] +OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] +def timeout( + seconds: float |int = None, + *, + exception: Optional[TimeoutException] = None, + on_timeout: Optional[OnTimeoutCallback] = None +): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + result_queue = queue.Queue(maxsize=1) + def target(): + try: + result = func(*args, **kwargs) + result_queue.put(result) + except Exception as e: + result_queue.put(e) + + thread = threading.Thread(target=target) + thread.daemon = True + thread.start() + + try: + result = result_queue.get(timeout=seconds) + if isinstance(result, Exception): + raise result + return result + except queue.Empty: + raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds") + + @wraps(func) + async def async_wrapper(*args, **kwargs) -> Any: + if seconds is None: + return await func(*args, **kwargs) + + try: + with trio.fail_after(seconds): + return await func(*args, **kwargs) + except trio.TooSlowError: + if on_timeout is not None: + if callable(on_timeout): + result = on_timeout() + if isinstance(result, Coroutine): + return await result + return result + return on_timeout + + if exception is None: + raise TimeoutError(f"Operation timed out after {seconds} seconds") + + if isinstance(exception, BaseException): + raise exception + + if isinstance(exception, type) and issubclass(exception, BaseException): + raise exception(f"Operation timed out after {seconds} seconds") + + raise RuntimeError("Invalid exception type provided") + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return wrapper + return decorator + diff --git a/api/utils/mcp_server.py b/api/utils/mcp_server.py deleted file mode 100644 index 83b168711..000000000 --- a/api/utils/mcp_server.py +++ 
/dev/null @@ -1,34 +0,0 @@ -from api.db.db_models import MCPServer -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions - - -def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]: - results = {} - tool_call_sessions = [] - try: - for mcp_server in mcp_servers: - server_key = mcp_server.id - - cached_tools = mcp_server.variables.get("tools", {}) - - tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) - tool_call_sessions.append(tool_call_session) - - try: - tools = tool_call_session.get_tools(timeout) - except Exception: - tools = [] - - results[server_key] = [] - for tool in tools: - tool_dict = tool.model_dump() - cached_tool = cached_tools.get(tool_dict["name"], {}) - - tool_dict["enabled"] = cached_tool.get("enabled", True) - results[server_key].append(tool_dict) - - # PERF: blocking call to close sessions — consider moving to background thread or task queue - close_multiple_mcp_toolcall_sessions(tool_call_sessions) - return results, "" - except Exception as e: - return {}, str(e) diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 4d8b33bfd..a400d3035 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -12,6 +12,8 @@ from typing import Callable from dataclasses import dataclass import networkx as nx import pandas as pd + +from api.utils.api_utils import timeout from graphrag.general import leiden from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT from graphrag.general.extractor import Extractor @@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor): res_str = [] res_dict = [] over, token_count = 0, 0 + @timeout(120) async def extract_community_report(community): nonlocal res_str, res_dict, over, token_count cm_id, cm = community @@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor): gen_conf = {"temperature": 0.3} async with chat_limiter: try: - with trio.move_on_after(120) as cancel_scope: + with trio.move_on_after(80) as cancel_scope: response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) if cancel_scope.cancelled_caught: logging.warning("extract_community_report._chat timeout, skipping...") diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 1bd9f2d8d..23134425f 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -21,6 +21,7 @@ from typing import Callable import trio import networkx as nx +from api.utils.api_utils import timeout from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \ handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange @@ -46,6 +47,7 @@ class Extractor: self._language = language self._entity_types = entity_types or DEFAULT_ENTITY_TYPES + @timeout(60) def _chat(self, system, history, gen_conf): hist = deepcopy(history) conf = deepcopy(gen_conf) diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 6e107bc87..fe54747f4 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -20,6 +20,7 @@ import trio from api import settings from api.utils import get_uuid +from api.utils.api_utils import timeout from graphrag.light.graph_extractor import GraphExtractor as LightKGExt 
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from graphrag.general.community_reports_extractor import CommunityReportsExtractor @@ -123,6 +124,7 @@ async def run_graphrag( return +@timeout(60*60*2) async def generate_subgraph( extractor: Extractor, tenant_id: str, @@ -194,6 +196,8 @@ async def generate_subgraph( callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph + +@timeout(60*3) async def merge_subgraph( tenant_id: str, kb_id: str, @@ -225,6 +229,7 @@ async def merge_subgraph( return new_graph +@timeout(60*60) async def resolve_entities( graph, subgraph_nodes: set[str], @@ -250,6 +255,7 @@ async def resolve_entities( callback(msg=f"Graph resolution done in {now - start:.2f}s.") +@timeout(60*30) async def extract_community( graph, tenant_id: str, diff --git a/graphrag/utils.py b/graphrag/utils.py index 81df2a24b..d414f2c41 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags): k = hasher.hexdigest() REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600) + def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True): """ Ensure all nodes and edges in the graph have some essential attribute. @@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True): if purged_edges and callback: callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.") + def get_from_to(node1, node2): if node1 < node2: return (node1, node2) else: return (node2, node1) + def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange): """Merge graph g2 into g1 in place.""" for node_name, attr in g2.nodes(data=True): @@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange): g1.graph["source_id"] += g2.graph.get("source_id", []) return g1 + def compute_args_hash(*args): return md5(str(args).encode()).hexdigest() @@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunk["q_%d_vec" % len(ebd)] = ebd chunks.append(chunk) + async def does_graph_contains(tenant_id, kb_id, doc_id): # Get doc_ids of graph fields = ["source_id"] @@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): graph_doc_ids = set(fields2[chunk_id]["source_id"]) return doc_id in graph_doc_ids + async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: conds = { "fields": ["source_id"], diff --git a/rag/raptor.py b/rag/raptor.py index a8d912f32..20a4334d0 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -20,6 +20,7 @@ import numpy as np from sklearn.mixture import GaussianMixture import trio +from api.utils.api_utils import timeout from graphrag.utils import ( get_llm_cache, get_embed_cache, @@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) return response + @timeout(2) async def _embedding_encode(self, txt): response = get_embed_cache(self._embd_model.llm_name, txt) if response is not None: @@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: layers = [(0, len(chunks))] start, end = 0, len(chunks) + @timeout(60) async def summarize(ck_idx: list[int]): nonlocal chunks texts = [chunks[i][0] for i in ck_idx] diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index adcaa4c23..7fdc570c2 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -21,6 +21,7 @@ import sys import threading 
import time +from api.utils.api_utils import timeout from api.utils.log_utils import init_root_logger, get_project_base_directory from graphrag.general.index import run_graphrag from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache @@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback): doc[PAGERANK_FLD] = int(task["pagerank"]) st = timer() + @timeout(60) async def upload_to_minio(document, chunk): try: d = copy.deepcopy(document) @@ -415,6 +417,7 @@ def init_kb(row, vector_size: int): return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) +@timeout(60*20) async def embedding(docs, mdl, parser_config=None, callback=None): if parser_config is None: parser_config = {} @@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): return tk_count, vector_size +@timeout(3600) async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): chunks = [] vctr_nm = "q_%d_vec"%vector_size @@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): return res, tk_count +@timeout(60*60*1.5) async def do_handle_task(task): task_id = task["id"] task_from_page = task["from_page"] diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index abfb26fb7..334f438ed 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -220,40 +220,43 @@ class RedisDB: logging.exception( "RedisDB.queue_product " + str(queue) + " got exception: " + str(e) ) + self.__open__() return False def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg: """https://redis.io/docs/latest/commands/xreadgroup/""" - try: - group_info = self.REDIS.xinfo_groups(queue_name) - if not any(gi["name"] == group_name for gi in group_info): - self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True) - args = { - "groupname": group_name, - "consumername": consumer_name, - "count": 1, - "block": 5, - "streams": {queue_name: msg_id}, - } - messages = self.REDIS.xreadgroup(**args) - if not messages: - return None - stream, element_list = messages[0] - if not element_list: - return None - msg_id, payload = element_list[0] - res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload) - return res - except Exception as e: - if str(e) == 'no such key': - pass - else: - logging.exception( - "RedisDB.queue_consumer " - + str(queue_name) - + " got exception: " - + str(e) - ) + for _ in range(3): + try: + group_info = self.REDIS.xinfo_groups(queue_name) + if not any(gi["name"] == group_name for gi in group_info): + self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True) + args = { + "groupname": group_name, + "consumername": consumer_name, + "count": 1, + "block": 5, + "streams": {queue_name: msg_id}, + } + messages = self.REDIS.xreadgroup(**args) + if not messages: + return None + stream, element_list = messages[0] + if not element_list: + return None + msg_id, payload = element_list[0] + res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload) + return res + except Exception as e: + if str(e) == 'no such key': + pass + else: + logging.exception( + "RedisDB.queue_consumer " + + str(queue_name) + + " got exception: " + + str(e) + ) + self.__open__() return None def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name): @@ -294,26 +297,30 @@ class RedisDB: return [] def requeue_msg(self, queue: str, group_name: str, msg_id: str): - try: - messages = self.REDIS.xrange(queue, msg_id, msg_id) 
- if messages: - self.REDIS.xadd(queue, messages[0][1]) - self.REDIS.xack(queue, group_name, msg_id) - except Exception as e: - logging.warning( - "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e) - ) + for _ in range(3): + try: + messages = self.REDIS.xrange(queue, msg_id, msg_id) + if messages: + self.REDIS.xadd(queue, messages[0][1]) + self.REDIS.xack(queue, group_name, msg_id) + except Exception as e: + logging.warning( + "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e) + ) + self.__open__() def queue_info(self, queue, group_name) -> dict | None: - try: - groups = self.REDIS.xinfo_groups(queue) - for group in groups: - if group["name"] == group_name: - return group - except Exception as e: - logging.warning( - "RedisDB.queue_info " + str(queue) + " got exception: " + str(e) - ) + for _ in range(3): + try: + groups = self.REDIS.xinfo_groups(queue) + for group in groups: + if group["name"] == group_name: + return group + except Exception as e: + logging.warning( + "RedisDB.queue_info " + str(queue) + " got exception: " + str(e) + ) + self.__open__() return None def delete_if_equal(self, key: str, expected_value: str) -> bool:
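### Notes

The new `timeout` helper in api/utils/api_utils.py decorates both sync and async callables: sync functions run in a daemon worker thread while the wrapper waits on a result queue and raises `TimeoutError` once the deadline passes; async functions are guarded by `trio.fail_after`, with the optional `exception` and `on_timeout` arguments honored only on that async path. A minimal usage sketch follows; `slow_sync_call` and `slow_async_call` are illustrative names, only `timeout` itself comes from this patch.

```python
# Minimal usage sketch for the timeout decorator added in this patch.
# slow_sync_call / slow_async_call are illustrative, not part of the codebase.
import time

import trio

from api.utils.api_utils import timeout


@timeout(2)
def slow_sync_call():
    # Runs in a daemon thread; the wrapper raises TimeoutError after ~2 seconds.
    time.sleep(10)


@timeout(2, on_timeout=lambda: "fallback")
async def slow_async_call():
    # Guarded by trio.fail_after(2); on timeout the on_timeout callable supplies the result.
    await trio.sleep(10)
    return "done"


if __name__ == "__main__":
    try:
        slow_sync_call()
    except TimeoutError as e:
        print(e)
    print(trio.run(slow_async_call))  # prints "fallback"
```

Note that the sync path does not kill the worker thread when the deadline passes; it only stops waiting for it, so the decorated call may keep running in the background.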
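The per-stage decorators build on this helper: generate_subgraph, merge_subgraph, resolve_entities, extract_community, embedding, run_raptor and do_handle_task each get a coarse outer budget, and inside community_reports_extractor.py the inner guard around the blocking chat call drops from `trio.move_on_after(120)` to `80`, keeping it shorter than the new `@timeout(120)` on `extract_community_report`; when the inner scope is cancelled the community is skipped rather than failing the whole stage. A scaled-down sketch of that outer-budget / inner-guard shape follows; the numbers and the `slow_llm_call` stand-in are illustrative (the real code dispatches the blocking call via `trio.to_thread.run_sync`).

```python
# Scaled-down sketch of the outer-budget / inner-guard nesting used in
# community_reports_extractor.py; slow_llm_call stands in for the LLM round-trip.
import trio

from api.utils.api_utils import timeout


async def slow_llm_call():
    await trio.sleep(5)  # pretend the model takes too long
    return "community report"


@timeout(2)  # outer budget for one community (120s in the patch)
async def extract_one_community():
    response = None
    with trio.move_on_after(1) as cancel_scope:  # inner guard (80s in the patch)
        response = await slow_llm_call()
    if cancel_scope.cancelled_caught:
        return None  # skip this community instead of failing the whole stage
    return response


if __name__ == "__main__":
    print(trio.run(extract_one_community))  # prints None after ~1 second
```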
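On the Redis side, queue_product now reopens the connection after a failure, and queue_consumer, requeue_msg and queue_info retry up to three times, calling `self.__open__()` between attempts instead of giving up on the first exception. A hedged sketch of that retry-and-reconnect shape against a plain redis-py client follows; the class, connection details and defaults below are assumptions, not the RedisDB implementation.

```python
# Hedged sketch of the retry-and-reconnect pattern this patch applies to the
# RedisDB stream helpers; only the overall shape mirrors the patch.
import logging

import redis


class StreamClientSketch:
    def __init__(self, host: str = "localhost", port: int = 6379):
        self._host, self._port = host, port
        self.__open__()

    def __open__(self):
        # Rebuild the client after a connection-level failure.
        self.REDIS = redis.Redis(host=self._host, port=self._port, decode_responses=True)

    def queue_info(self, queue: str, group_name: str) -> dict | None:
        for _ in range(3):
            try:
                for group in self.REDIS.xinfo_groups(queue):
                    if group["name"] == group_name:
                        return group
                return None  # stream exists but the group is missing
            except Exception as e:
                logging.warning("queue_info %s got exception: %s", queue, e)
                self.__open__()  # reconnect, then retry
        return None
```

The bounded loop gives a transient connection drop a chance to recover instead of immediately returning `None`, and the reconnect in the exception path is what the new `self.__open__()` calls add over the previous single-attempt versions.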