Feat: GraphRAG handle cancel gracefully (#11061)

### What problem does this PR solve?

 GraghRAG handle cancel gracefully. #10997.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-11-06 16:12:20 +08:00
committed by GitHub
parent 66c01c7274
commit 23b81eae77
10 changed files with 206 additions and 47 deletions

View File

@ -29,6 +29,8 @@ import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -67,7 +69,8 @@ class EntityResolution(Extractor):
async def __call__(self, graph: nx.Graph,
subgraph_nodes: set[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None) -> EntityResolutionResult:
callback: Callable | None = None,
task_id: str = "") -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
@ -109,7 +112,7 @@ class EntityResolution(Extractor):
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
await self._resolve_candidate(candidate_batch, result_set, result_lock)
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
if cancel_scope.cancelled_caught:
@ -136,7 +139,7 @@ class EntityResolution(Extractor):
async def limited_merge_nodes(graph, nodes, change):
async with semaphore:
await self._merge_graph_nodes(graph, nodes, change)
await self._merge_graph_nodes(graph, nodes, change, task_id)
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
@ -153,7 +156,12 @@ class EntityResolution(Extractor):
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
@ -173,7 +181,7 @@ class EntityResolution(Extractor):
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return