mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Perf: Enhance timeout handling. (#8826)
### What problem does this PR solve? ### Type of change - [x] Performance Improvement
This commit is contained in:
@ -12,6 +12,8 @@ from typing import Callable
|
||||
from dataclasses import dataclass
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.general import leiden
|
||||
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from graphrag.general.extractor import Extractor
|
||||
@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
@timeout(120)
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
cm_id, cm = community
|
||||
@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
gen_conf = {"temperature": 0.3}
|
||||
async with chat_limiter:
|
||||
try:
|
||||
with trio.move_on_after(120) as cancel_scope:
|
||||
with trio.move_on_after(80) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
|
||||
@ -21,6 +21,7 @@ from typing import Callable
|
||||
import trio
|
||||
import networkx as nx
|
||||
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
|
||||
@ -46,6 +47,7 @@ class Extractor:
|
||||
self._language = language
|
||||
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
|
||||
|
||||
@timeout(60)
|
||||
def _chat(self, system, history, gen_conf):
|
||||
hist = deepcopy(history)
|
||||
conf = deepcopy(gen_conf)
|
||||
|
||||
@ -20,6 +20,7 @@ import trio
|
||||
|
||||
from api import settings
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
@ -123,6 +124,7 @@ async def run_graphrag(
|
||||
return
|
||||
|
||||
|
||||
@timeout(60*60*2)
|
||||
async def generate_subgraph(
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
@ -194,6 +196,8 @@ async def generate_subgraph(
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
return subgraph
|
||||
|
||||
|
||||
@timeout(60*3)
|
||||
async def merge_subgraph(
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
@ -225,6 +229,7 @@ async def merge_subgraph(
|
||||
return new_graph
|
||||
|
||||
|
||||
@timeout(60*60)
|
||||
async def resolve_entities(
|
||||
graph,
|
||||
subgraph_nodes: set[str],
|
||||
@ -250,6 +255,7 @@ async def resolve_entities(
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
|
||||
|
||||
@timeout(60*30)
|
||||
async def extract_community(
|
||||
graph,
|
||||
tenant_id: str,
|
||||
|
||||
@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
|
||||
k = hasher.hexdigest()
|
||||
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
||||
|
||||
|
||||
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
|
||||
"""
|
||||
Ensure all nodes and edges in the graph have some essential attribute.
|
||||
@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
|
||||
if purged_edges and callback:
|
||||
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
|
||||
|
||||
|
||||
def get_from_to(node1, node2):
|
||||
if node1 < node2:
|
||||
return (node1, node2)
|
||||
else:
|
||||
return (node2, node1)
|
||||
|
||||
|
||||
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
|
||||
"""Merge graph g2 into g1 in place."""
|
||||
for node_name, attr in g2.nodes(data=True):
|
||||
@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
|
||||
g1.graph["source_id"] += g2.graph.get("source_id", [])
|
||||
return g1
|
||||
|
||||
|
||||
def compute_args_hash(*args):
|
||||
return md5(str(args).encode()).hexdigest()
|
||||
|
||||
@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
|
||||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
# Get doc_ids of graph
|
||||
fields = ["source_id"]
|
||||
@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
graph_doc_ids = set(fields2[chunk_id]["source_id"])
|
||||
return doc_id in graph_doc_ids
|
||||
|
||||
|
||||
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
conds = {
|
||||
"fields": ["source_id"],
|
||||
|
||||
Reference in New Issue
Block a user