mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-31 23:55:06 +08:00
Feat: GraphRAG handle cancel gracefully (#11061)
### What problem does this PR solve? GraghRAG handle cancel gracefully. #10997. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -29,6 +29,8 @@ import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@ -67,7 +69,8 @@ class EntityResolution(Extractor):
|
||||
async def __call__(self, graph: nx.Graph,
|
||||
subgraph_nodes: set[str],
|
||||
prompt_variables: dict[str, Any] | None = None,
|
||||
callback: Callable | None = None) -> EntityResolutionResult:
|
||||
callback: Callable | None = None,
|
||||
task_id: str = "") -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
@ -109,7 +112,7 @@ class EntityResolution(Extractor):
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
await self._resolve_candidate(candidate_batch, result_set, result_lock)
|
||||
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
|
||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
||||
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
|
||||
if cancel_scope.cancelled_caught:
|
||||
@ -136,7 +139,7 @@ class EntityResolution(Extractor):
|
||||
|
||||
async def limited_merge_nodes(graph, nodes, change):
|
||||
async with semaphore:
|
||||
await self._merge_graph_nodes(graph, nodes, change)
|
||||
await self._merge_graph_nodes(graph, nodes, change, task_id)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
@ -153,7 +156,12 @@ class EntityResolution(Extractor):
|
||||
change=change,
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
@ -173,7 +181,7 @@ class EntityResolution(Extractor):
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
|
||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user