Feat: GraphRAG handle cancel gracefully (#11061)

### What problem does this PR solve?

 GraghRAG handle cancel gracefully. #10997.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-11-06 16:12:20 +08:00
committed by GitHub
parent 66c01c7274
commit 23b81eae77
10 changed files with 206 additions and 47 deletions

View File

@ -23,7 +23,9 @@ from typing import Callable
import networkx as nx
import trio
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.token_utils import truncate
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import (
GraphChange,
@ -38,7 +40,7 @@ from graphrag.utils import (
)
from rag.llm.chat_model import Base as CompletionLLM
from rag.prompts.generator import message_fit_in
from common.token_utils import truncate
from common.exceptions import TaskCanceledException
GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
@ -60,7 +62,7 @@ class Extractor:
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
@timeout(60 * 20)
def _chat(self, system, history, gen_conf={}):
def _chat(self, system, history, gen_conf={}, task_id=""):
hist = deepcopy(history)
conf = deepcopy(gen_conf)
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
@ -69,6 +71,12 @@ class Extractor:
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = self._llm.chat(system_msg[0]["content"], hist, conf)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
@ -99,25 +107,29 @@ class Extractor:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation)
return dict(maybe_nodes), dict(maybe_edges)
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None):
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
self.callback = callback
start_ts = trio.current_time()
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK):
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
limiter = trio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int):
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
nonlocal error_count
async with limiter:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction")
try:
await self._process_single_content(chunk_key_dp, idx, total, out_results)
await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id)
except Exception as e:
error_count += 1
error_msg = f"Error processing chunk {idx+1}/{total}: {str(e)}"
error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}"
logging.warning(error_msg)
if self.callback:
self.callback(msg=error_msg)
@ -127,7 +139,7 @@ class Extractor:
async with trio.open_nursery() as nursery:
for i, ck in enumerate(chunks):
nursery.start_soon(worker, (doc_id, ck), i, len(chunks))
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
if error_count > 0:
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@ -137,7 +149,13 @@ class Extractor:
return out_results
out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before entity extraction")
out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after entity extraction")
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
@ -154,9 +172,17 @@ class Extractor:
start_ts = now
logging.info("Entities merging...")
all_entities_data = []
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
now = trio.current_time()
if self.callback:
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@ -164,9 +190,17 @@ class Extractor:
start_ts = now
logging.info("Relationships merging...")
all_relationships_data = []
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
now = trio.current_time()
if self.callback:
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@ -181,7 +215,10 @@ class Extractor:
return all_entities_data, all_relationships_data
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""):
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes")
if not entities:
return
entity_type = sorted(
@ -191,7 +228,7 @@ class Extractor:
)[0][0]
description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities])))
already_source_ids = flat_uniq_list(entities, "source_id")
description = await self._handle_entity_relation_summary(entity_name, description)
description = await self._handle_entity_relation_summary(entity_name, description, task_id=task_id)
node_data = dict(
entity_type=entity_type,
description=description,
@ -200,18 +237,21 @@ class Extractor:
node_data["entity_name"] = entity_name
all_relationships_data.append(node_data)
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None):
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""):
if not edges_data:
return
weight = sum([edge["weight"] for edge in edges_data])
description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id)
keywords = flat_uniq_list(edges_data, "keywords")
source_id = flat_uniq_list(edges_data, "source_id")
edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id)
all_relationships_data.append(edge_data)
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""):
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes")
if len(nodes) <= 1:
return
change.added_updated_nodes.add(nodes[0])
@ -220,6 +260,9 @@ class Extractor:
node0_attrs = graph.nodes[nodes[0]]
node0_neighbors = set(graph.neighbors(nodes[0]))
for node1 in nodes[1:]:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes")
# Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
node1_attrs = graph.nodes[node1]
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
@ -236,15 +279,18 @@ class Extractor:
edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
for attr in ["keywords", "source_id"]:
edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"], task_id=task_id)
graph.add_edge(nodes[0], neighbor, **edge0_attrs)
else:
graph.add_edge(nodes[0], neighbor, **edge1_attrs)
graph.remove_node(node1)
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id)
graph.nodes[nodes[0]].update(node0_attrs)
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str) -> str:
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens)
description_list = use_description.split(GRAPH_FIELD_SEP)
@ -258,6 +304,10 @@ class Extractor:
)
use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}")
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}])
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary