mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: GraphRAG handle cancel gracefully (#11061)
### What problem does this PR solve? GraghRAG handle cancel gracefully. #10997. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -14,6 +14,8 @@ from dataclasses import dataclass
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.connection_utils import timeout
|
||||
from graphrag.general import leiden
|
||||
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
@ -51,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
@ -64,6 +66,11 @@ class CommunityReportsExtractor(Extractor):
|
||||
@timeout(120)
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during community report extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
cm_id, cm = community
|
||||
weight = cm["weight"]
|
||||
ents = cm["nodes"]
|
||||
@ -95,7 +102,10 @@ class CommunityReportsExtractor(Extractor):
|
||||
async with chat_limiter:
|
||||
try:
|
||||
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {})
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before LLM call.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
@ -136,6 +146,9 @@ class CommunityReportsExtractor(Extractor):
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for community in comm.items():
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
nursery.start_soon(extract_community_report, community)
|
||||
if callback:
|
||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
||||
|
||||
Reference in New Issue
Block a user