From 23b81eae77da3152653d6d9583b39439ba5cb264 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Thu, 6 Nov 2025 16:12:20 +0800 Subject: [PATCH] Feat: GraphRAG handle cancel gracefully (#11061) ### What problem does this PR solve? GraphRAG handles cancellation gracefully. #10997. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/kb_app.py | 14 +-- common/exceptions.py | 18 ++++ graphrag/entity_resolution.py | 18 ++-- .../general/community_reports_extractor.py | 17 +++- graphrag/general/extractor.py | 90 ++++++++++++++----- graphrag/general/graph_extractor.py | 4 +- graphrag/general/index.py | 77 +++++++++++++++- graphrag/light/graph_extractor.py | 8 +- rag/flow/pipeline.py | 2 +- rag/svr/task_executor.py | 5 +- 10 files changed, 206 insertions(+), 47 deletions(-) create mode 100644 common/exceptions.py diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 36b6bf78f..74e90db21 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -38,7 +38,7 @@ from api.utils.api_utils import get_json_result from rag.nlp import search from api.constants import DATASET_NAME_LIMIT from rag.utils.redis_conn import REDIS_CONN -from rag.utils.doc_store_conn import OrderByExpr +from rag.utils.doc_store_conn import OrderByExpr from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings @@ -52,7 +52,7 @@ def create(): tenant_id = current_user.id, parser_id = req.pop("parser_id", None), **req - ) + ) try: if not KnowledgebaseService.save(**req): @@ -571,7 +571,7 @@ def trace_graphrag(): ok, task = TaskService.get_by_id(task_id) if not ok: - return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred") + return get_json_result(data={}) return get_json_result(data=task.to_dict()) @@ -780,14 +780,14 @@ def check_embedding(): def _to_1d(x): a = np.asarray(x, dtype=np.float32) - return a.reshape(-1) + return a.reshape(-1) def _cos_sim(a, b, 
eps=1e-12): a = _to_1d(a) b = _to_1d(b) na = np.linalg.norm(a) nb = np.linalg.norm(b) - if na < eps or nb < eps: + if na < eps or nb < eps: return 0.0 return float(np.dot(a, b) / (na * nb)) @@ -825,7 +825,7 @@ def check_embedding(): indexNames=index_nm, knowledgebaseIds=[kb_id] ) ids = docStoreConn.getChunkIds(res1) - if not ids: + if not ids: continue cid = ids[0] @@ -869,7 +869,7 @@ def check_embedding(): continue try: - qv, _ = emb_mdl.encode_queries(txt) + qv, _ = emb_mdl.encode_queries(txt) sim = _cos_sim(qv, ck["vector"]) except Exception: return get_error_data_result(message="embedding failure") diff --git a/common/exceptions.py b/common/exceptions.py new file mode 100644 index 000000000..c0caac484 --- /dev/null +++ b/common/exceptions.py @@ -0,0 +1,18 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +class TaskCanceledException(Exception): + def __init__(self, msg): + self.msg = msg diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 11bacd15e..7ffc52538 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -29,6 +29,8 @@ import editdistance from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange +from api.db.services.task_service import has_canceled +from common.exceptions import TaskCanceledException DEFAULT_RECORD_DELIMITER = "##" DEFAULT_ENTITY_INDEX_DELIMITER = "<|>" @@ -67,7 +69,8 @@ class EntityResolution(Extractor): async def __call__(self, graph: nx.Graph, subgraph_nodes: set[str], prompt_variables: dict[str, Any] | None = None, - callback: Callable | None = None) -> EntityResolutionResult: + callback: Callable | None = None, + task_id: str = "") -> EntityResolutionResult: """Call method definition.""" if prompt_variables is None: prompt_variables = {} @@ -109,7 +112,7 @@ class EntityResolution(Extractor): try: enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - await self._resolve_candidate(candidate_batch, result_set, result_lock) + await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id) remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. 
") if cancel_scope.cancelled_caught: @@ -136,7 +139,7 @@ class EntityResolution(Extractor): async def limited_merge_nodes(graph, nodes, change): async with semaphore: - await self._merge_graph_nodes(graph, nodes, change) + await self._merge_graph_nodes(graph, nodes, change, task_id) async with trio.open_nursery() as nursery: for sub_connect_graph in nx.connected_components(connect_graph): @@ -153,7 +156,12 @@ class EntityResolution(Extractor): change=change, ) - async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock): + async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""): + if task_id: + if has_canceled(task_id): + logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + pair_txt = [ f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] for index, candidate in enumerate(candidate_resolution_i[1]): @@ -173,7 +181,7 @@ class EntityResolution(Extractor): try: enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}) + response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) if cancel_scope.cancelled_caught: logging.warning("_resolve_candidate._chat timeout, skipping...") return diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 6c49c0a73..09634fb4d 100644 --- a/graphrag/general/community_reports_extractor.py +++ 
b/graphrag/general/community_reports_extractor.py @@ -14,6 +14,8 @@ from dataclasses import dataclass import networkx as nx import pandas as pd +from api.db.services.task_service import has_canceled +from common.exceptions import TaskCanceledException from common.connection_utils import timeout from graphrag.general import leiden from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT @@ -51,7 +53,7 @@ class CommunityReportsExtractor(Extractor): self._extraction_prompt = COMMUNITY_REPORT_PROMPT self._max_report_length = max_report_length or 1500 - async def __call__(self, graph: nx.Graph, callback: Callable | None = None): + async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""): enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") for node_degree in graph.degree: graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1]) @@ -64,6 +66,11 @@ class CommunityReportsExtractor(Extractor): @timeout(120) async def extract_community_report(community): nonlocal res_str, res_dict, over, token_count + if task_id: + if has_canceled(task_id): + logging.info(f"Task {task_id} cancelled during community report extraction.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + cm_id, cm = community weight = cm["weight"] ents = cm["nodes"] @@ -95,7 +102,10 @@ class CommunityReportsExtractor(Extractor): async with chat_limiter: try: with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope: - response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}) + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before LLM call.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) if cancel_scope.cancelled_caught: logging.warning("extract_community_report._chat timeout, 
skipping...") return @@ -136,6 +146,9 @@ class CommunityReportsExtractor(Extractor): for level, comm in communities.items(): logging.info(f"Level {level}: Community: {len(comm.keys())}") for community in comm.items(): + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before community processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") nursery.start_soon(extract_community_report, community) if callback: callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}") diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 9b18d694f..1df38ed1c 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -23,7 +23,9 @@ from typing import Callable import networkx as nx import trio +from api.db.services.task_service import has_canceled from common.connection_utils import timeout +from common.token_utils import truncate from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT from graphrag.utils import ( GraphChange, @@ -38,7 +40,7 @@ from graphrag.utils import ( ) from rag.llm.chat_model import Base as CompletionLLM from rag.prompts.generator import message_fit_in -from common.token_utils import truncate +from common.exceptions import TaskCanceledException GRAPH_FIELD_SEP = "" DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"] @@ -60,7 +62,7 @@ class Extractor: self._entity_types = entity_types or DEFAULT_ENTITY_TYPES @timeout(60 * 20) - def _chat(self, system, history, gen_conf={}): + def _chat(self, system, history, gen_conf={}, task_id=""): hist = deepcopy(history) conf = deepcopy(gen_conf) response = get_llm_cache(self._llm.llm_name, system, hist, conf) @@ -69,6 +71,12 @@ class Extractor: _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92)) response = "" for attempt in range(3): + + if task_id: + if has_canceled(task_id): + 
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + try: response = self._llm.chat(system_msg[0]["content"], hist, conf) response = re.sub(r"^.*", "", response, flags=re.DOTALL) @@ -99,25 +107,29 @@ class Extractor: maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation) return dict(maybe_nodes), dict(maybe_edges) - async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None): + async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""): self.callback = callback start_ts = trio.current_time() - async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK): + async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): out_results = [] error_count = 0 max_errors = 3 limiter = trio.Semaphore(max_concurrency) - async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int): + async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""): nonlocal error_count async with limiter: + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction") + try: - await self._process_single_content(chunk_key_dp, idx, total, out_results) + await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id) except Exception as e: error_count += 1 - error_msg = f"Error processing chunk {idx+1}/{total}: {str(e)}" + error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}" logging.warning(error_msg) if self.callback: self.callback(msg=error_msg) @@ -127,7 +139,7 @@ class Extractor: async with trio.open_nursery() as nursery: for i, ck in enumerate(chunks): - nursery.start_soon(worker, (doc_id, ck), i, len(chunks)) + nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id) if 
error_count > 0: warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)" @@ -137,7 +149,13 @@ class Extractor: return out_results - out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK) + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled before entity extraction") + + out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=task_id) + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled after entity extraction") maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) @@ -154,9 +172,17 @@ class Extractor: start_ts = now logging.info("Entities merging...") all_entities_data = [] + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging") + async with trio.open_nursery() as nursery: for en_nm, ents in maybe_nodes.items(): - nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data) + nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id) + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging") + now = trio.current_time() if self.callback: self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.") @@ -164,9 +190,17 @@ class Extractor: start_ts = now logging.info("Relationships merging...") all_relationships_data = [] + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging") + async with trio.open_nursery() as nursery: for (src, tgt), rels in maybe_edges.items(): - nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data) + nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id) + + if task_id and has_canceled(task_id): + 
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging") + now = trio.current_time() if self.callback: self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.") @@ -181,7 +215,10 @@ class Extractor: return all_entities_data, all_relationships_data - async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data): + async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""): + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes") + if not entities: return entity_type = sorted( @@ -191,7 +228,7 @@ class Extractor: )[0][0] description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities]))) already_source_ids = flat_uniq_list(entities, "source_id") - description = await self._handle_entity_relation_summary(entity_name, description) + description = await self._handle_entity_relation_summary(entity_name, description, task_id=task_id) node_data = dict( entity_type=entity_type, description=description, @@ -200,18 +237,21 @@ class Extractor: node_data["entity_name"] = entity_name all_relationships_data.append(node_data) - async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None): + async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""): if not edges_data: return weight = sum([edge["weight"] for edge in edges_data]) description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data]))) - description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description) + description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id) keywords = flat_uniq_list(edges_data, "keywords") source_id = flat_uniq_list(edges_data, "source_id") edge_data = dict(src_id=src_id, 
tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id) all_relationships_data.append(edge_data) - async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange): + async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""): + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes") + if len(nodes) <= 1: return change.added_updated_nodes.add(nodes[0]) @@ -220,6 +260,9 @@ class Extractor: node0_attrs = graph.nodes[nodes[0]] node0_neighbors = set(graph.neighbors(nodes[0])) for node1 in nodes[1:]: + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes") + # Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged. node1_attrs = graph.nodes[node1] node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}" @@ -236,15 +279,18 @@ class Extractor: edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}" for attr in ["keywords", "source_id"]: edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr])) - edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"]) + edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"], task_id=task_id) graph.add_edge(nodes[0], neighbor, **edge0_attrs) else: graph.add_edge(nodes[0], neighbor, **edge1_attrs) graph.remove_node(node1) - node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"]) + node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id) graph.nodes[nodes[0]].update(node0_attrs) - async def _handle_entity_relation_summary(self, entity_or_relation_name: str, 
description: str) -> str: + async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str: + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling") + summary_max_tokens = 512 use_description = truncate(description, summary_max_tokens) description_list = use_description.split(GRAPH_FIELD_SEP) @@ -258,6 +304,10 @@ class Extractor: ) use_prompt = prompt_template.format(**context_base) logging.info(f"Trigger summary: {entity_or_relation_name}") + + if task_id and has_canceled(task_id): + raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling") + async with chat_limiter: - summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}]) + summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) return summary diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index 59ebeeddf..d156fcb2e 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -97,7 +97,7 @@ class GraphExtractor(Extractor): self._entity_types_key: ",".join(entity_types), } - async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results): + async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""): token_count = 0 chunk_key = chunk_key_dp[0] content = chunk_key_dp[1] @@ -107,7 +107,7 @@ class GraphExtractor(Extractor): } hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], {})) + response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id) 
token_count += num_tokens_from_string(hint_prompt + response) results = response or "" diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 94a28252c..12b39400e 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -21,6 +21,8 @@ import networkx as nx import trio from api.db.services.document_service import DocumentService +from api.db.services.task_service import has_canceled +from common.exceptions import TaskCanceledException from common.misc_utils import get_uuid from common.connection_utils import timeout from graphrag.entity_resolution import EntityResolution @@ -106,6 +108,7 @@ async def run_graphrag( chat_model, embedding_model, callback, + task_id=row["id"], ) if with_community: await graphrag_task_lock.spin_acquire() @@ -118,6 +121,7 @@ async def run_graphrag( chat_model, embedding_model, callback, + task_id=row["id"], ) finally: graphrag_task_lock.release() @@ -207,6 +211,10 @@ async def run_graphrag_for_kb( failed_docs: list[tuple[str, str]] = [] # (doc_id, error) async def build_one(doc_id: str): + if has_canceled(row["id"]): + callback(msg=f"Task {row['id']} cancelled, stopping execution.") + raise TaskCanceledException(f"Task {row['id']} was cancelled") + chunks = all_doc_chunks.get(doc_id, []) if not chunks: callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.") @@ -232,6 +240,7 @@ async def run_graphrag_for_kb( chat_model, embedding_model, callback, + task_id=row["id"] ) if sg: subgraphs[doc_id] = sg @@ -239,14 +248,24 @@ async def run_graphrag_for_kb( else: failed_docs.append((doc_id, "subgraph is empty")) callback(msg=f"{msg} empty") + except TaskCanceledException as canceled: + callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}") except Exception as e: failed_docs.append((doc_id, repr(e))) callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}") + if has_canceled(row["id"]): + callback(msg=f"Task {row['id']} cancelled before processing 
documents.") + raise TaskCanceledException(f"Task {row['id']} was cancelled") + async with trio.open_nursery() as nursery: for doc_id in doc_ids: nursery.start_soon(build_one, doc_id) + if has_canceled(row["id"]): + callback(msg=f"Task {row['id']} cancelled after document processing.") + raise TaskCanceledException(f"Task {row['id']} was cancelled") + ok_docs = [d for d in doc_ids if d in subgraphs] if not ok_docs: callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.") @@ -257,6 +276,10 @@ async def run_graphrag_for_kb( await kb_lock.spin_acquire() callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired") + if has_canceled(row["id"]): + callback(msg=f"Task {row['id']} cancelled before merging subgraphs.") + raise TaskCanceledException(f"Task {row['id']} was cancelled") + try: union_nodes: set = set() final_graph = None @@ -288,6 +311,10 @@ async def run_graphrag_for_kb( callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} + if has_canceled(row["id"]): + callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.") + raise TaskCanceledException(f"Task {row['id']} was cancelled") + await kb_lock.spin_acquire() callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community") @@ -306,6 +333,7 @@ async def run_graphrag_for_kb( chat_model, embedding_model, callback, + task_id=row["id"], ) if with_community: @@ -317,6 +345,7 @@ async def run_graphrag_for_kb( chat_model, embedding_model, callback, + task_id=row["id"], ) finally: kb_lock.release() @@ -343,7 +372,12 @@ async def generate_subgraph( llm_bdl, embed_bdl, callback, + task_id: str = "", ): + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.") + raise TaskCanceledException(f"Task 
{task_id} was cancelled") + contains = await does_graph_contains(tenant_id, kb_id, doc_id) if contains: callback(msg=f"Graph already contains {doc_id}") @@ -354,15 +388,24 @@ async def generate_subgraph( language=language, entity_types=entity_types, ) - ents, rels = await ext(doc_id, chunks, callback) + ents, rels = await ext(doc_id, chunks, callback, task_id=task_id) subgraph = nx.Graph() + for ent in ents: + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + assert "description" in ent, f"entity {ent} does not have description" ent["source_id"] = [doc_id] subgraph.add_node(ent["entity_name"], **ent) ignored_rels = 0 for rel in rels: + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + assert "description" in rel, f"relation {rel} does not have description" if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]): ignored_rels += 1 @@ -434,17 +477,27 @@ async def resolve_entities( llm_bdl, embed_bdl, callback, + task_id: str = "", ): + # Check if task has been canceled before resolution + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during entity resolution.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + start = trio.current_time() er = EntityResolution( llm_bdl, ) - reso = await er(graph, subgraph_nodes, callback=callback) + reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id) graph = reso.graph change = reso.change callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.") callback(msg="Graph resolution updated pagerank.") + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled after entity resolution.") 
+ raise TaskCanceledException(f"Task {task_id} was cancelled") + await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback) now = trio.current_time() callback(msg=f"Graph resolution done in {now - start:.2f}s.") @@ -459,12 +512,22 @@ async def extract_community( llm_bdl, embed_bdl, callback, + task_id: str = "", ): + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled before community extraction.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + start = trio.current_time() ext = CommunityReportsExtractor( llm_bdl, ) - cr = await ext(graph, callback=callback) + cr = await ext(graph, callback=callback, task_id=task_id) + + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during community extraction.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + community_structure = cr.structured_output community_reports = cr.output doc_ids = graph.graph["source_id"] @@ -472,6 +535,10 @@ async def extract_community( now = trio.current_time() callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.") start = now + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled during community indexing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + chunks = [] for stru, rep in zip(community_structure, community_reports): obj = { @@ -509,6 +576,10 @@ async def extract_community( error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
raise Exception(error_message) + if task_id and has_canceled(task_id): + callback(msg=f"Task {task_id} cancelled after community indexing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + now = trio.current_time() callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.") return community_structure, community_reports diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index c2827c00f..e698c2b9f 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -71,7 +71,7 @@ class GraphExtractor(Extractor): self._left_token_count = llm_invoker.max_length - num_tokens_from_string(self._entity_extract_prompt.format(**self._context_base, input_text="")) self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count) - async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results): + async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""): token_count = 0 chunk_key = chunk_key_dp[0] content = chunk_key_dp[1] @@ -86,13 +86,13 @@ class GraphExtractor(Extractor): if self.callback: self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...") async with chat_limiter: - final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf) + final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id) token_count += num_tokens_from_string(hint_prompt + final_result) history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt) for now_glean_index in range(self._max_gleanings): async with chat_limiter: # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) - glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf) 
+ glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) history.extend([{"role": "assistant", "content": glean_result}]) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) final_result += glean_result @@ -101,7 +101,7 @@ class GraphExtractor(Extractor): history.extend([{"role": "user", "content": self._if_loop_prompt}]) async with chat_limiter: - if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf) + if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index a1db65bbc..b44c77bd4 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -41,7 +41,7 @@ class Pipeline(Graph): self._doc_id = None def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None: - from rag.svr.task_executor import TaskCanceledException + from common.exceptions import TaskCanceledException log_key = f"{self._flow_id}-{self.task_id}-logs" timestamp = timer() if has_canceled(self.task_id): diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0f7a4e319..d07a44ea4 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -65,6 +65,7 @@ from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from graphrag.utils import chat_limiter from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc +from common.exceptions import TaskCanceledException from common import settings from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME @@ -127,9 +128,7 @@ def 
signal_handler(sig, frame): sys.exit(0) -class TaskCanceledException(Exception): - def __init__(self, msg): - self.msg = msg + def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):