Perf: Enhance timeout handling. (#8826)
### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
@@ -1,3 +1,18 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required

@@ -6,9 +21,10 @@ from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.settings import RetCode

from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.mcp_server import get_mcp_tools
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
    get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Optional, Union, Callable, Coroutine, Type
from urllib.parse import quote, urlencode
from uuid import uuid1

import trio

from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

import requests
from flask import (
    Response,
@@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
        transformed_data[mapped_key] = value

    return transformed_data


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id

            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})

                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)
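The PERF note above leaves the session cleanup synchronous. A hypothetical follow-up sketch (not part of this PR) of the background-thread option it mentions, assuming `close_multiple_mcp_toolcall_sessions` is safe to call off the request thread:

```python
import logging
import threading

from rag.utils.mcp_tool_call_conn import close_multiple_mcp_toolcall_sessions


def close_sessions_in_background(tool_call_sessions):
    # Hypothetical helper: fire-and-forget cleanup so get_mcp_tools can return immediately.
    def _close():
        try:
            close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        except Exception:
            logging.exception("Failed to close MCP tool-call sessions")

    # Daemon thread: does not block the request path and dies with the process.
    threading.Thread(target=_close, daemon=True).start()
```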
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]


def timeout(
    seconds: float | int = None,
    *,
    exception: Optional[TimeoutException] = None,
    on_timeout: Optional[OnTimeoutCallback] = None
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)

            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            try:
                result = result_queue.get(timeout=seconds)
                if isinstance(result, Exception):
                    raise result
                return result
            except queue.Empty:
                raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)

            try:
                with trio.fail_after(seconds):
                    return await func(*args, **kwargs)
            except trio.TooSlowError:
                if on_timeout is not None:
                    if callable(on_timeout):
                        result = on_timeout()
                        if isinstance(result, Coroutine):
                            return await result
                        return result
                    return on_timeout

                if exception is None:
                    raise TimeoutError(f"Operation timed out after {seconds} seconds")

                if isinstance(exception, BaseException):
                    raise exception

                if isinstance(exception, type) and issubclass(exception, BaseException):
                    raise exception(f"Operation timed out after {seconds} seconds")

                raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator
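For reference, the decorator is applied bare at the call sites later in this diff: sync functions go through the daemon-thread wrapper, async functions through `trio.fail_after`. A minimal usage sketch, assuming the import path this PR adds; `slow_parse` and `slow_llm_call` are hypothetical examples with stub bodies:

```python
from api.utils.api_utils import timeout


@timeout(5)
def slow_parse(payload):
    # Sync function: runs in a daemon thread, raises TimeoutError after 5 seconds.
    ...


@timeout(60, on_timeout=lambda: None)
async def slow_llm_call(prompt):
    # Async function: guarded by trio.fail_after(60); returns None instead of raising on timeout.
    ...
```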
@@ -1,34 +0,0 @@
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id

            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})

                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)
@@ -12,6 +12,8 @@ from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd

from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
@@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
        res_str = []
        res_dict = []
        over, token_count = 0, 0
        @timeout(120)
        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
@@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
            gen_conf = {"temperature": 0.3}
            async with chat_limiter:
                try:
                    with trio.move_on_after(120) as cancel_scope:
                    with trio.move_on_after(80) as cancel_scope:
                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")
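For context on the hunk above: `trio.move_on_after` only cancels the enclosed block and records the fact in `cancel_scope.cancelled_caught`, while the new `@timeout(120)` on the enclosing async function (which wraps it in `trio.fail_after`) raises if the whole report extraction overruns. A standalone sketch of the soft-cancel pattern, with a sleep standing in for the slow call:

```python
import trio


async def bounded_call():
    # Soft limit: give up on this block after 2 seconds without raising.
    with trio.move_on_after(2) as cancel_scope:
        await trio.sleep(10)  # hypothetical stand-in for a slow awaitable
        result = "report"
    if cancel_scope.cancelled_caught:
        print("call timed out, skipping...")
        return None
    return result


print(trio.run(bounded_call))
```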
@@ -21,6 +21,7 @@ from typing import Callable
import trio
import networkx as nx

from api.utils.api_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
@@ -46,6 +47,7 @@ class Extractor:
        self._language = language
        self._entity_types = entity_types or DEFAULT_ENTITY_TYPES

    @timeout(60)
    def _chat(self, system, history, gen_conf):
        hist = deepcopy(history)
        conf = deepcopy(gen_conf)
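Worth noting for the hunk above: `_chat` is synchronous, so `@timeout(60)` selects the thread-based wrapper, meaning the caller receives a `TimeoutError` after 60 seconds while the underlying request may keep running in its daemon thread. A reduced, self-contained sketch of that behavior; `sync_timeout` and `stalled_chat` are hypothetical stand-ins, not the library code:

```python
import queue
import threading
import time
from functools import wraps


def sync_timeout(seconds):
    # Reduced copy of the sync branch of the timeout decorator, for illustration only.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            box = queue.Queue(maxsize=1)
            threading.Thread(target=lambda: box.put(func(*args, **kwargs)), daemon=True).start()
            try:
                return box.get(timeout=seconds)
            except queue.Empty:
                raise TimeoutError(f"{func.__name__} timed out after {seconds}s")
        return wrapper
    return decorator


@sync_timeout(1)
def stalled_chat():
    time.sleep(5)  # hypothetical stand-in for a hanging LLM request
    return "response"


try:
    stalled_chat()
except TimeoutError as e:
    print(e)  # the caller unblocks after ~1s; the worker thread keeps sleeping
```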
@@ -20,6 +20,7 @@ import trio

from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -123,6 +124,7 @@ async def run_graphrag(
        return


@timeout(60*60*2)
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
@@ -194,6 +196,8 @@ async def generate_subgraph(
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


@timeout(60*3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
@@ -225,6 +229,7 @@ async def merge_subgraph(
    return new_graph


@timeout(60*60)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
@@ -250,6 +255,7 @@ async def resolve_entities(
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30)
async def extract_community(
    graph,
    tenant_id: str,
@@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """
    Ensure all nodes and edges in the graph have some essential attribute.
@@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
@@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()

@@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of graph
    fields = ["source_id"]
@@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
    graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {
        "fields": ["source_id"],
@@ -20,6 +20,7 @@ import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from api.utils.api_utils import timeout
from graphrag.utils import (
    get_llm_cache,
    get_embed_cache,
@@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
            set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    @timeout(2)
    async def _embedding_encode(self, txt):
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response is not None:
@@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        @timeout(60)
        async def summarize(ck_idx: list[int]):
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]
@@ -21,6 +21,7 @@ import sys
import threading
import time

from api.utils.api_utils import timeout
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback):
        doc[PAGERANK_FLD] = int(task["pagerank"])
    st = timer()

    @timeout(60)
    async def upload_to_minio(document, chunk):
        try:
            d = copy.deepcopy(document)
@@ -415,6 +417,7 @@ def init_kb(row, vector_size: int):
    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


@timeout(60*20)
async def embedding(docs, mdl, parser_config=None, callback=None):
    if parser_config is None:
        parser_config = {}
@@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    return tk_count, vector_size


@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    chunks = []
    vctr_nm = "q_%d_vec" % vector_size
@@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    return res, tk_count


@timeout(60*60*1.5)
async def do_handle_task(task):
    task_id = task["id"]
    task_from_page = task["from_page"]
@@ -220,10 +220,12 @@ class RedisDB:
                logging.exception(
                    "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()
        return False

    def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
        """https://redis.io/docs/latest/commands/xreadgroup/"""
        for _ in range(3):
            try:
                group_info = self.REDIS.xinfo_groups(queue_name)
                if not any(gi["name"] == group_name for gi in group_info):
@@ -254,6 +256,7 @@ class RedisDB:
                    + " got exception: "
                    + str(e)
                )
                self.__open__()
        return None

    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
@@ -294,6 +297,7 @@ class RedisDB:
        return []

    def requeue_msg(self, queue: str, group_name: str, msg_id: str):
        for _ in range(3):
            try:
                messages = self.REDIS.xrange(queue, msg_id, msg_id)
                if messages:
@@ -303,8 +307,10 @@ class RedisDB:
                logging.warning(
                    "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()

    def queue_info(self, queue, group_name) -> dict | None:
        for _ in range(3):
            try:
                groups = self.REDIS.xinfo_groups(queue)
                for group in groups:
@@ -314,6 +320,7 @@ class RedisDB:
                logging.warning(
                    "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()
        return None

    def delete_if_equal(self, key: str, expected_value: str) -> bool:
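The RedisDB hunks above follow a retry-then-reconnect pattern: each operation is attempted up to three times, and on an exception the error is logged and the connection is re-opened via `self.__open__()` before the next attempt (or a final `False`/`None` is returned). A generic sketch of that pattern; `with_retry`, `do_operation`, and `reconnect` are hypothetical stand-ins, not RedisDB methods:

```python
import logging


def with_retry(do_operation, reconnect, attempts=3):
    # Hypothetical helper mirroring the RedisDB pattern: try, log, reconnect, retry.
    for _ in range(attempts):
        try:
            return do_operation()
        except Exception as e:
            logging.warning("operation got exception: %s", e)
            reconnect()
    return None
```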