Perf: Enhance timeout handling. (#8826)

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement
This commit is contained in:
Kevin Hu
2025-07-15 09:36:45 +08:00
committed by GitHub
parent ce140f1393
commit c642dbefca
10 changed files with 207 additions and 85 deletions

View File

@ -12,6 +12,8 @@ from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
res_str = []
res_dict = []
over, token_count = 0, 0
@timeout(120)
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
cm_id, cm = community
@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
gen_conf = {"temperature": 0.3}
async with chat_limiter:
try:
with trio.move_on_after(120) as cancel_scope:
with trio.move_on_after(80) as cancel_scope:
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")

View File

@ -21,6 +21,7 @@ from typing import Callable
import trio
import networkx as nx
from api.utils.api_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
@ -46,6 +47,7 @@ class Extractor:
self._language = language
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
@timeout(60)
def _chat(self, system, history, gen_conf):
hist = deepcopy(history)
conf = deepcopy(gen_conf)

View File

@ -20,6 +20,7 @@ import trio
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@ -123,6 +124,7 @@ async def run_graphrag(
return
@timeout(60*60*2)
async def generate_subgraph(
extractor: Extractor,
tenant_id: str,
@ -194,6 +196,8 @@ async def generate_subgraph(
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@timeout(60*3)
async def merge_subgraph(
tenant_id: str,
kb_id: str,
@ -225,6 +229,7 @@ async def merge_subgraph(
return new_graph
@timeout(60*60)
async def resolve_entities(
graph,
subgraph_nodes: set[str],
@ -250,6 +255,7 @@ async def resolve_entities(
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@timeout(60*30)
async def extract_community(
graph,
tenant_id: str,

View File

@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
"""
Ensure all nodes and edges in the graph have some essential attribute.
@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
if purged_edges and callback:
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
def get_from_to(node1, node2):
if node1 < node2:
return (node1, node2)
else:
return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
"""Merge graph g2 into g1 in place."""
for node_name, attr in g2.nodes(data=True):
@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
g1.graph["source_id"] += g2.graph.get("source_id", [])
return g1
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
fields = ["source_id"]
@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
graph_doc_ids = set(fields2[chunk_id]["source_id"])
return doc_id in graph_doc_ids
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {
"fields": ["source_id"],