Feat: GraphRAG handle cancel gracefully (#11061)

### What problem does this PR solve?

GraphRAG now handles task cancellation gracefully. #10997.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
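
For reviewers skimming the diff: the change threads a `task_id` into `generate_subgraph`, `resolve_entities`, and `extract_community`, and repeats the same guard before and after each expensive stage — check `has_canceled(task_id)`, report through `callback`, and raise `TaskCanceledException`. A minimal sketch of that pattern (the `check_canceled` helper and the in-memory task set are illustrative stand-ins, not code from this PR):

```python
class TaskCanceledException(Exception):
    """Stand-in for common.exceptions.TaskCanceledException."""


_canceled_tasks: set[str] = set()  # hypothetical in-memory stand-in for the persisted task state


def has_canceled(task_id: str) -> bool:
    # stand-in for api.db.services.task_service.has_canceled
    return task_id in _canceled_tasks


def check_canceled(task_id: str, callback, stage: str) -> None:
    # the guard this PR repeats before/after each expensive GraphRAG stage
    if task_id and has_canceled(task_id):
        callback(msg=f"Task {task_id} cancelled {stage}.")
        raise TaskCanceledException(f"Task {task_id} was cancelled")


# Example, mirroring the checks added to resolve_entities()/extract_community():
_canceled_tasks.add("task-123")
try:
    check_canceled("task-123", lambda msg: print(msg), "before community extraction")
except TaskCanceledException as e:
    print(f"stopped: {e}")
```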
Yongteng Lei
2025-11-06 16:12:20 +08:00
committed by GitHub
parent 66c01c7274
commit 23b81eae77
10 changed files with 206 additions and 47 deletions


@@ -21,6 +21,8 @@ import networkx as nx
import trio
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from graphrag.entity_resolution import EntityResolution
@@ -106,6 +108,7 @@ async def run_graphrag(
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
await graphrag_task_lock.spin_acquire()
@@ -118,6 +121,7 @@ async def run_graphrag(
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
graphrag_task_lock.release()
@@ -207,6 +211,10 @@ async def run_graphrag_for_kb(
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
@@ -232,6 +240,7 @@ async def run_graphrag_for_kb(
chat_model,
embedding_model,
callback,
task_id=row["id"]
)
if sg:
subgraphs[doc_id] = sg
@@ -239,14 +248,24 @@ async def run_graphrag_for_kb(
else:
failed_docs.append((doc_id, "subgraph is empty"))
callback(msg=f"{msg} empty")
except TaskCanceledException as canceled:
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
except Exception as e:
failed_docs.append((doc_id, repr(e)))
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before processing documents.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
async with trio.open_nursery() as nursery:
for doc_id in doc_ids:
nursery.start_soon(build_one, doc_id)
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
@@ -257,6 +276,10 @@ async def run_graphrag_for_kb(
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
try:
union_nodes: set = set()
final_graph = None
@@ -288,6 +311,10 @@ async def run_graphrag_for_kb(
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
@@ -306,6 +333,7 @@ async def run_graphrag_for_kb(
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
@@ -317,6 +345,7 @@ async def run_graphrag_for_kb(
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
kb_lock.release()
@@ -343,7 +372,12 @@ async def generate_subgraph(
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
if contains:
callback(msg=f"Graph already contains {doc_id}")
@@ -354,15 +388,24 @@ async def generate_subgraph(
language=language,
entity_types=entity_types,
)
ents, rels = await ext(doc_id, chunks, callback)
ents, rels = await ext(doc_id, chunks, callback, task_id=task_id)
subgraph = nx.Graph()
for ent in ents:
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
assert "description" in ent, f"entity {ent} does not have description"
ent["source_id"] = [doc_id]
subgraph.add_node(ent["entity_name"], **ent)
ignored_rels = 0
for rel in rels:
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
assert "description" in rel, f"relation {rel} does not have description"
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
ignored_rels += 1
@@ -434,17 +477,27 @@ async def resolve_entities(
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
# Check if task has been canceled before resolution
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
er = EntityResolution(
llm_bdl,
)
reso = await er(graph, subgraph_nodes, callback=callback)
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
graph = reso.graph
change = reso.change
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
callback(msg="Graph resolution updated pagerank.")
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled after entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
now = trio.current_time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@@ -459,12 +512,22 @@ async def extract_community(
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled before community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
ext = CommunityReportsExtractor(
llm_bdl,
)
cr = await ext(graph, callback=callback)
cr = await ext(graph, callback=callback, task_id=task_id)
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
community_structure = cr.structured_output
community_reports = cr.output
doc_ids = graph.graph["source_id"]
@@ -472,6 +535,10 @@ async def extract_community(
now = trio.current_time()
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
start = now
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
chunks = []
for stru, rep in zip(community_structure, community_reports):
obj = {
@@ -509,6 +576,10 @@ async def extract_community(
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled after community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
now = trio.current_time()
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
return community_structure, community_reports