mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 23:16:58 +08:00
Feat: GraphRAG handle cancel gracefully (#11061)
### What problem does this PR solve? GraghRAG handle cancel gracefully. #10997. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -21,6 +21,8 @@ import networkx as nx
|
||||
import trio
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.misc_utils import get_uuid
|
||||
from common.connection_utils import timeout
|
||||
from graphrag.entity_resolution import EntityResolution
|
||||
@ -106,6 +108,7 @@ async def run_graphrag(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
if with_community:
|
||||
await graphrag_task_lock.spin_acquire()
|
||||
@ -118,6 +121,7 @@ async def run_graphrag(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
graphrag_task_lock.release()
|
||||
@ -207,6 +211,10 @@ async def run_graphrag_for_kb(
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
@ -232,6 +240,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"]
|
||||
)
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
@ -239,14 +248,24 @@ async def run_graphrag_for_kb(
|
||||
else:
|
||||
failed_docs.append((doc_id, "subgraph is empty"))
|
||||
callback(msg=f"{msg} empty")
|
||||
except TaskCanceledException as canceled:
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
|
||||
except Exception as e:
|
||||
failed_docs.append((doc_id, repr(e)))
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for doc_id in doc_ids:
|
||||
nursery.start_soon(build_one, doc_id)
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
@ -257,6 +276,10 @@ async def run_graphrag_for_kb(
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
try:
|
||||
union_nodes: set = set()
|
||||
final_graph = None
|
||||
@ -288,6 +311,10 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||
|
||||
@ -306,6 +333,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
|
||||
if with_community:
|
||||
@ -317,6 +345,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
kb_lock.release()
|
||||
@ -343,7 +372,12 @@ async def generate_subgraph(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
|
||||
if contains:
|
||||
callback(msg=f"Graph already contains {doc_id}")
|
||||
@ -354,15 +388,24 @@ async def generate_subgraph(
|
||||
language=language,
|
||||
entity_types=entity_types,
|
||||
)
|
||||
ents, rels = await ext(doc_id, chunks, callback)
|
||||
ents, rels = await ext(doc_id, chunks, callback, task_id=task_id)
|
||||
subgraph = nx.Graph()
|
||||
|
||||
for ent in ents:
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
assert "description" in ent, f"entity {ent} does not have description"
|
||||
ent["source_id"] = [doc_id]
|
||||
subgraph.add_node(ent["entity_name"], **ent)
|
||||
|
||||
ignored_rels = 0
|
||||
for rel in rels:
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
assert "description" in rel, f"relation {rel} does not have description"
|
||||
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
|
||||
ignored_rels += 1
|
||||
@ -434,17 +477,27 @@ async def resolve_entities(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
# Check if task has been canceled before resolution
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during entity resolution.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
|
||||
graph = reso.graph
|
||||
change = reso.change
|
||||
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
|
||||
callback(msg="Graph resolution updated pagerank.")
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled after entity resolution.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
@ -459,12 +512,22 @@ async def extract_community(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled before community extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
)
|
||||
cr = await ext(graph, callback=callback)
|
||||
cr = await ext(graph, callback=callback, task_id=task_id)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during community extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
community_structure = cr.structured_output
|
||||
community_reports = cr.output
|
||||
doc_ids = graph.graph["source_id"]
|
||||
@ -472,6 +535,10 @@ async def extract_community(
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
start = now
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during community indexing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
chunks = []
|
||||
for stru, rep in zip(community_structure, community_reports):
|
||||
obj = {
|
||||
@ -509,6 +576,10 @@ async def extract_community(
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled after community indexing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
return community_structure, community_reports
|
||||
|
||||
Reference in New Issue
Block a user