Perf: Enhance timeout handling. (#8826)

### What problem does this PR solve?

Adds a reusable `timeout` decorator to `api/utils/api_utils.py` (thread-based for synchronous functions, `trio.fail_after` for async ones), applies it to the GraphRAG, RAPTOR and task-executor pipelines, moves `get_mcp_tools` there from `api/utils/mcp_server.py`, and makes the Redis queue helpers retry and reconnect after connection errors.

### Type of change

- [x] Performance Improvement
Kevin Hu
2025-07-15 09:36:45 +08:00
committed by GitHub
parent ce140f1393
commit c642dbefca
10 changed files with 207 additions and 85 deletions

View File

@@ -1,3 +1,18 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required
@@ -6,9 +21,10 @@ from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
-from api.utils.mcp_server import get_mcp_tools
+from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
+    get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

View File

@@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Optional, Union, Callable, Coroutine, Type
from urllib.parse import quote, urlencode
from uuid import uuid1

import trio
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
import requests
from flask import (
    Response,
@@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
        transformed_data[mapped_key] = value
    return transformed_data


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)
            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)


TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]


def timeout(
    seconds: float | int = None,
    *,
    exception: Optional[TimeoutException] = None,
    on_timeout: Optional[OnTimeoutCallback] = None
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)

            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            try:
                result = result_queue.get(timeout=seconds)
                if isinstance(result, Exception):
                    raise result
                return result
            except queue.Empty:
                raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)
            try:
                with trio.fail_after(seconds):
                    return await func(*args, **kwargs)
            except trio.TooSlowError:
                if on_timeout is not None:
                    if callable(on_timeout):
                        result = on_timeout()
                        if isinstance(result, Coroutine):
                            return await result
                        return result
                    return on_timeout
                if exception is None:
                    raise TimeoutError(f"Operation timed out after {seconds} seconds")
                if isinstance(exception, BaseException):
                    raise exception
                if isinstance(exception, type) and issubclass(exception, BaseException):
                    raise exception(f"Operation timed out after {seconds} seconds")
                raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator

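For orientation, a minimal usage sketch of the new decorator follows (the function names are hypothetical; the import path matches the hunks below). Synchronous functions go through the thread-based wrapper, while async functions are guarded by `trio.fail_after`, so the async path only works inside a running trio event loop:

```python
from api.utils.api_utils import timeout


@timeout(5)
def fetch_tools():
    # Sync path: runs in a daemon thread; if no result arrives within 5 s the
    # wrapper raises TimeoutError (the abandoned thread may keep running).
    ...


@timeout(5, exception=RuntimeError)
async def summarize_chunk():
    # Async path: wrapped in trio.fail_after(5); on timeout the wrapper raises
    # RuntimeError because an explicit exception class was supplied.
    ...


@timeout(5, on_timeout=lambda: None)
async def best_effort_call():
    # Async path with a fallback: on timeout the on_timeout callback is invoked
    # and its result (here None) is returned instead of raising.
    ...
```

Note that `exception` and `on_timeout` are only consulted by the async wrapper; the synchronous wrapper always raises `TimeoutError`.
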
View File

@@ -1,34 +0,0 @@
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)
            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)

View File

@@ -12,6 +12,8 @@ from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd

from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
@@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
        res_str = []
        res_dict = []
        over, token_count = 0, 0
        @timeout(120)
        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
@@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
            gen_conf = {"temperature": 0.3}
            async with chat_limiter:
                try:
-                    with trio.move_on_after(120) as cancel_scope:
+                    with trio.move_on_after(80) as cancel_scope:
                        response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")

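The hunk above layers two budgets: `@timeout(120)` bounds the whole per-community extraction, while the inner `trio.move_on_after(80)` lets a single slow chat call be skipped. A rough sketch of the inner pattern, mirroring the structure in the diff (hypothetical helper, values taken from the hunk):

```python
import trio


async def chat_with_budget(chat_fn, text, history, gen_conf):
    with trio.move_on_after(80) as cancel_scope:
        response = await trio.to_thread.run_sync(chat_fn, text, history, gen_conf)
    if cancel_scope.cancelled_caught:
        # The 80 s budget elapsed; the caller logs a warning and skips this
        # community instead of failing the whole extraction.
        return None
    return response
```
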
View File

@@ -21,6 +21,7 @@ from typing import Callable
import trio
import networkx as nx

from api.utils.api_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
@@ -46,6 +47,7 @@ class Extractor:
        self._language = language
        self._entity_types = entity_types or DEFAULT_ENTITY_TYPES

    @timeout(60)
    def _chat(self, system, history, gen_conf):
        hist = deepcopy(history)
        conf = deepcopy(gen_conf)

View File

@@ -20,6 +20,7 @@ import trio
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -123,6 +124,7 @@ async def run_graphrag(
        return


@timeout(60*60*2)
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
@@ -194,6 +196,7 @@ async def generate_subgraph(
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


@timeout(60*3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
@@ -225,6 +229,7 @@ async def merge_subgraph(
    return new_graph


@timeout(60*60)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
@@ -250,6 +255,7 @@ async def resolve_entities(
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30)
async def extract_community(
    graph,
    tenant_id: str,

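The decorator arguments in the GraphRAG index hunk above are plain arithmetic over seconds. Spelled out for readability (the constant names below are illustrative, not from the codebase):

```python
# Budgets applied by this commit:
GENERATE_SUBGRAPH_BUDGET = 60 * 60 * 2   # 7200 s = 2 hours
MERGE_SUBGRAPH_BUDGET = 60 * 3           # 180 s = 3 minutes
RESOLVE_ENTITIES_BUDGET = 60 * 60        # 3600 s = 1 hour
EXTRACT_COMMUNITY_BUDGET = 60 * 30       # 1800 s = 30 minutes
```

Since no `exception` or `on_timeout` argument is passed at these call sites, exceeding a budget surfaces as a plain `TimeoutError` from the async wrapper, which the surrounding pipeline's error handling can report.
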
View File

@@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """
    Ensure all nodes and edges in the graph have some essential attribute.
@@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
@@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()
@@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of graph
    fields = ["source_id"]
@@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
    graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {
        "fields": ["source_id"],

View File

@@ -20,6 +20,7 @@ import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from api.utils.api_utils import timeout
from graphrag.utils import (
    get_llm_cache,
    get_embed_cache,
@@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    @timeout(2)
    async def _embedding_encode(self, txt):
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response is not None:
@@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        @timeout(60)
        async def summarize(ck_idx: list[int]):
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]

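In the RAPTOR file, the decorated `summarize` closure is typically launched once per cluster; under trio, a `TimeoutError` from one task cancels its siblings in the same nursery. A rough sketch of that interaction (assumed call pattern, not verbatim from the file):

```python
import trio


async def summarize_layer(clusters):
    # Each summarize(...) call now carries its own 60 s budget via @timeout(60).
    # If one of them times out, the nursery cancels the remaining tasks and
    # propagates the TimeoutError to the caller.
    async with trio.open_nursery() as nursery:
        for ck_idx in clusters:
            nursery.start_soon(summarize, ck_idx)
```
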
View File

@@ -21,6 +21,7 @@ import sys
import threading
import time

from api.utils.api_utils import timeout
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback):
        doc[PAGERANK_FLD] = int(task["pagerank"])
    st = timer()

    @timeout(60)
    async def upload_to_minio(document, chunk):
        try:
            d = copy.deepcopy(document)
@@ -415,6 +417,7 @@ def init_kb(row, vector_size: int):
    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


@timeout(60*20)
async def embedding(docs, mdl, parser_config=None, callback=None):
    if parser_config is None:
        parser_config = {}
@@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    return tk_count, vector_size


@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    chunks = []
    vctr_nm = "q_%d_vec"%vector_size
@@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    return res, tk_count


@timeout(60*60*1.5)
async def do_handle_task(task):
    task_id = task["id"]
    task_from_page = task["from_page"]

View File

@@ -220,10 +220,12 @@ class RedisDB:
                logging.exception(
                    "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()
        return False

    def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
        """https://redis.io/docs/latest/commands/xreadgroup/"""
        for _ in range(3):
            try:
                group_info = self.REDIS.xinfo_groups(queue_name)
                if not any(gi["name"] == group_name for gi in group_info):
@@ -254,6 +256,7 @@ class RedisDB:
                    + " got exception: "
                    + str(e)
                )
                self.__open__()
        return None

    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
@@ -294,6 +297,7 @@ class RedisDB:
            return []

    def requeue_msg(self, queue: str, group_name: str, msg_id: str):
        for _ in range(3):
            try:
                messages = self.REDIS.xrange(queue, msg_id, msg_id)
                if messages:
@@ -303,8 +307,10 @@ class RedisDB:
                logging.warning(
                    "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()

    def queue_info(self, queue, group_name) -> dict | None:
        for _ in range(3):
            try:
                groups = self.REDIS.xinfo_groups(queue)
                for group in groups:
@@ -314,6 +320,7 @@ class RedisDB:
                logging.warning(
                    "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()
        return None

    def delete_if_equal(self, key: str, expected_value: str) -> bool:
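
The Redis changes all follow one pattern: up to three attempts, and after any exception the connection is re-opened via `__open__()` before the next try. As a compact sketch, paraphrasing `queue_info` above (the group-matching lines are assumed from the surrounding code, not shown in the hunk):

```python
def queue_info(self, queue, group_name) -> dict | None:
    for _ in range(3):
        try:
            groups = self.REDIS.xinfo_groups(queue)
            for group in groups:
                if group["name"] == group_name:
                    return group
        except Exception as e:
            logging.warning("RedisDB.queue_info " + str(queue) + " got exception: " + str(e))
            self.__open__()  # reconnect, then retry on the next iteration
    return None
```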