Feat: GraphRAG handle cancel gracefully (#11061)

### What problem does this PR solve?

Make GraphRAG handle cancellation gracefully. #10997.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
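
For reviewers, a minimal sketch of the pattern this change threads through the GraphRAG pipeline: an optional `task_id` is passed down the call chain, `has_canceled(task_id)` is polled at stage boundaries and before each LLM call, and `TaskCanceledException` is raised to unwind cooperatively. The `run_stage` helper below is illustrative only, not part of the diff:

```python
# Illustrative sketch, not code from this PR: the real changes inline this
# polling directly into Extractor, EntityResolution, CommunityReportsExtractor,
# and the index-building functions.
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException


async def run_stage(task_id: str, stage_name: str, stage, callback):
    # Poll the cancel flag before starting an expensive stage.
    if task_id and has_canceled(task_id):
        callback(msg=f"Task {task_id} cancelled before {stage_name}.")
        raise TaskCanceledException(f"Task {task_id} was cancelled")
    result = await stage()
    # Poll again afterwards so a cancel issued mid-stage stops the pipeline.
    if task_id and has_canceled(task_id):
        callback(msg=f"Task {task_id} cancelled after {stage_name}.")
        raise TaskCanceledException(f"Task {task_id} was cancelled")
    return result
```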
Author: Yongteng Lei
Date: 2025-11-06 16:12:20 +08:00 (committed by GitHub)
Parent: 66c01c7274
Commit: 23b81eae77
10 changed files with 206 additions and 47 deletions


@@ -38,7 +38,7 @@ from api.utils.api_utils import get_json_result
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.utils.redis_conn import REDIS_CONN
-from rag.utils.doc_store_conn import OrderByExpr
+from rag.utils.doc_store_conn import OrderByExpr
 from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
 from common import settings
@@ -52,7 +52,7 @@ def create():
         tenant_id = current_user.id,
         parser_id = req.pop("parser_id", None),
         **req
-    )
+    )
     try:
         if not KnowledgebaseService.save(**req):
@@ -571,7 +571,7 @@ def trace_graphrag():
     ok, task = TaskService.get_by_id(task_id)
     if not ok:
         return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
-    return get_json_result(data={})
+    return get_json_result(data=task.to_dict())
@@ -780,14 +780,14 @@ def check_embedding():
     def _to_1d(x):
         a = np.asarray(x, dtype=np.float32)
-        return a.reshape(-1)
+        return a.reshape(-1)

     def _cos_sim(a, b, eps=1e-12):
         a = _to_1d(a)
         b = _to_1d(b)
         na = np.linalg.norm(a)
         nb = np.linalg.norm(b)
-        if na < eps or nb < eps:
+        if na < eps or nb < eps:
             return 0.0
         return float(np.dot(a, b) / (na * nb))
@@ -825,7 +825,7 @@
             indexNames=index_nm, knowledgebaseIds=[kb_id]
         )
         ids = docStoreConn.getChunkIds(res1)
-        if not ids:
+        if not ids:
             continue
         cid = ids[0]
@@ -869,7 +869,7 @@
                 continue
             try:
-                qv, _ = emb_mdl.encode_queries(txt)
+                qv, _ = emb_mdl.encode_queries(txt)
                 sim = _cos_sim(qv, ck["vector"])
             except Exception:
                 return get_error_data_result(message="embedding failure")

common/exceptions.py (new file, +18)

@@ -0,0 +1,18 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+class TaskCanceledException(Exception):
+    def __init__(self, msg):
+        self.msg = msg
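
The exception deliberately stores its reason on `.msg` rather than passing it to `Exception.__init__`, so callers that need the text should read `.msg` explicitly. A brief usage sketch (the `callback=print` wiring is hypothetical; only `TaskCanceledException` comes from this file):

```python
# Hypothetical caller: demonstrates reading .msg when reporting a cancel.
from common.exceptions import TaskCanceledException


def report_cancel(task_id: str, callback=print) -> None:
    try:
        raise TaskCanceledException(f"Task {task_id} was cancelled")
    except TaskCanceledException as canceled:
        # Cancellation is reported through the callback, not counted as a failure.
        callback(f"[GraphRAG] stopped: {canceled.msg}")
```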


@@ -29,6 +29,8 @@ import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
+from api.db.services.task_service import has_canceled
+from common.exceptions import TaskCanceledException

 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -67,7 +69,8 @@ class EntityResolution(Extractor):
     async def __call__(self, graph: nx.Graph,
                        subgraph_nodes: set[str],
                        prompt_variables: dict[str, Any] | None = None,
-                       callback: Callable | None = None) -> EntityResolutionResult:
+                       callback: Callable | None = None,
+                       task_id: str = "") -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
@@ -109,7 +112,7 @@
                 try:
                     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
                     with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                        await self._resolve_candidate(candidate_batch, result_set, result_lock)
+                        await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
                     remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
                     callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
                     if cancel_scope.cancelled_caught:
@@ -136,7 +139,7 @@
         async def limited_merge_nodes(graph, nodes, change):
             async with semaphore:
-                await self._merge_graph_nodes(graph, nodes, change)
+                await self._merge_graph_nodes(graph, nodes, change, task_id)

         async with trio.open_nursery() as nursery:
             for sub_connect_graph in nx.connected_components(connect_graph):
@@ -153,7 +156,12 @@
                 change=change,
             )

-    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
+    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
+        if task_id:
+            if has_canceled(task_id):
+                logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
+                raise TaskCanceledException(f"Task {task_id} was cancelled")
+
         pair_txt = [
             f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
         for index, candidate in enumerate(candidate_resolution_i[1]):
@@ -173,7 +181,7 @@
         try:
             enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
             with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
+                response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
             if cancel_scope.cancelled_caught:
                 logging.warning("_resolve_candidate._chat timeout, skipping...")
                 return


@@ -14,6 +14,8 @@ from dataclasses import dataclass
 import networkx as nx
 import pandas as pd
+from api.db.services.task_service import has_canceled
+from common.exceptions import TaskCanceledException
 from common.connection_utils import timeout
 from graphrag.general import leiden
 from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
@@ -51,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
         self._extraction_prompt = COMMUNITY_REPORT_PROMPT
         self._max_report_length = max_report_length or 1500

-    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+    async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
         enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@@ -64,6 +66,11 @@
         @timeout(120)
         async def extract_community_report(community):
             nonlocal res_str, res_dict, over, token_count
+            if task_id:
+                if has_canceled(task_id):
+                    logging.info(f"Task {task_id} cancelled during community report extraction.")
+                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+
             cm_id, cm = community
             weight = cm["weight"]
             ents = cm["nodes"]
@@ -95,7 +102,10 @@
             async with chat_limiter:
                 try:
                     with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                        response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {})
+                        if task_id and has_canceled(task_id):
+                            logging.info(f"Task {task_id} cancelled before LLM call.")
+                            raise TaskCanceledException(f"Task {task_id} was cancelled")
+                        response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
                     if cancel_scope.cancelled_caught:
                         logging.warning("extract_community_report._chat timeout, skipping...")
                         return
@@ -136,6 +146,9 @@
             for level, comm in communities.items():
                 logging.info(f"Level {level}: Community: {len(comm.keys())}")
                 for community in comm.items():
+                    if task_id and has_canceled(task_id):
+                        logging.info(f"Task {task_id} cancelled before community processing.")
+                        raise TaskCanceledException(f"Task {task_id} was cancelled")
                     nursery.start_soon(extract_community_report, community)
         if callback:
             callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")


@@ -23,7 +23,9 @@ from typing import Callable
 import networkx as nx
 import trio
+from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
+from common.token_utils import truncate
 from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
 from graphrag.utils import (
     GraphChange,
@@ -38,7 +40,7 @@ from graphrag.utils import (
 )
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.prompts.generator import message_fit_in
-from common.token_utils import truncate
+from common.exceptions import TaskCanceledException

 GRAPH_FIELD_SEP = "<SEP>"
 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
@@ -60,7 +62,7 @@ class Extractor:
         self._entity_types = entity_types or DEFAULT_ENTITY_TYPES

     @timeout(60 * 20)
-    def _chat(self, system, history, gen_conf={}):
+    def _chat(self, system, history, gen_conf={}, task_id=""):
         hist = deepcopy(history)
         conf = deepcopy(gen_conf)
         response = get_llm_cache(self._llm.llm_name, system, hist, conf)
@@ -69,6 +71,12 @@
         _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
         response = ""
         for attempt in range(3):
+            if task_id:
+                if has_canceled(task_id):
+                    logging.info(f"Task {task_id} cancelled during LLM chat.")
+                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+
             try:
                 response = self._llm.chat(system_msg[0]["content"], hist, conf)
                 response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
@@ -99,25 +107,29 @@
             maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation)
         return dict(maybe_nodes), dict(maybe_edges)

-    async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None):
+    async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
         self.callback = callback
         start_ts = trio.current_time()

-        async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK):
+        async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
             out_results = []
             error_count = 0
             max_errors = 3
             limiter = trio.Semaphore(max_concurrency)

-            async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int):
+            async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
                 nonlocal error_count
                 async with limiter:
+                    if task_id and has_canceled(task_id):
+                        raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction")
+
                     try:
-                        await self._process_single_content(chunk_key_dp, idx, total, out_results)
+                        await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id)
                     except Exception as e:
                         error_count += 1
-                        error_msg = f"Error processing chunk {idx+1}/{total}: {str(e)}"
+                        error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}"
                         logging.warning(error_msg)
                         if self.callback:
                             self.callback(msg=error_msg)
@@ -127,7 +139,7 @@
             async with trio.open_nursery() as nursery:
                 for i, ck in enumerate(chunks):
-                    nursery.start_soon(worker, (doc_id, ck), i, len(chunks))
+                    nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)

             if error_count > 0:
                 warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@@ -137,7 +149,13 @@
             return out_results

-        out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK)
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled before entity extraction")
+
+        out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=task_id)
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled after entity extraction")

         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
@@ -154,9 +172,17 @@
             start_ts = now

         logging.info("Entities merging...")
         all_entities_data = []
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
+
         async with trio.open_nursery() as nursery:
             for en_nm, ents in maybe_nodes.items():
-                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")

         now = trio.current_time()
         if self.callback:
             self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@@ -164,9 +190,17 @@
             start_ts = now

         logging.info("Relationships merging...")
         all_relationships_data = []
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
+
         async with trio.open_nursery() as nursery:
             for (src, tgt), rels in maybe_edges.items():
-                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")

         now = trio.current_time()
         if self.callback:
             self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@@ -181,7 +215,10 @@
         return all_entities_data, all_relationships_data

-    async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
+    async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""):
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes")
+
         if not entities:
             return
         entity_type = sorted(
@@ -191,7 +228,7 @@
         )[0][0]
         description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities])))
         already_source_ids = flat_uniq_list(entities, "source_id")
-        description = await self._handle_entity_relation_summary(entity_name, description)
+        description = await self._handle_entity_relation_summary(entity_name, description, task_id=task_id)
         node_data = dict(
             entity_type=entity_type,
             description=description,
@@ -200,18 +237,21 @@
         node_data["entity_name"] = entity_name
         all_relationships_data.append(node_data)

-    async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None):
+    async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""):
         if not edges_data:
             return
         weight = sum([edge["weight"] for edge in edges_data])
         description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
-        description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
+        description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id)
         keywords = flat_uniq_list(edges_data, "keywords")
         source_id = flat_uniq_list(edges_data, "source_id")
         edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id)
         all_relationships_data.append(edge_data)

-    async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
+    async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""):
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes")
+
         if len(nodes) <= 1:
             return
         change.added_updated_nodes.add(nodes[0])
@@ -220,6 +260,9 @@
         node0_attrs = graph.nodes[nodes[0]]
         node0_neighbors = set(graph.neighbors(nodes[0]))
         for node1 in nodes[1:]:
+            if task_id and has_canceled(task_id):
+                raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes")
+
             # Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
             node1_attrs = graph.nodes[node1]
             node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
@@ -236,15 +279,18 @@
                 edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
                 for attr in ["keywords", "source_id"]:
                     edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
-                edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
+                edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"], task_id=task_id)
                 graph.add_edge(nodes[0], neighbor, **edge0_attrs)
             else:
                 graph.add_edge(nodes[0], neighbor, **edge1_attrs)
             graph.remove_node(node1)
-        node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
+        node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id)
         graph.nodes[nodes[0]].update(node0_attrs)

-    async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str) -> str:
+    async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str:
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
+
         summary_max_tokens = 512
         use_description = truncate(description, summary_max_tokens)
         description_list = use_description.split(GRAPH_FIELD_SEP)
@@ -258,6 +304,10 @@
         )
         use_prompt = prompt_template.format(**context_base)
         logging.info(f"Trigger summary: {entity_or_relation_name}")
+
+        if task_id and has_canceled(task_id):
+            raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
+
         async with chat_limiter:
-            summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}])
+            summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
         return summary


@@ -97,7 +97,7 @@ class GraphExtractor(Extractor):
             self._entity_types_key: ",".join(entity_types),
         }

-    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
+    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
         token_count = 0
         chunk_key = chunk_key_dp[0]
         content = chunk_key_dp[1]
@@ -107,7 +107,7 @@
         }
         hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
         async with chat_limiter:
-            response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], {}))
+            response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
         token_count += num_tokens_from_string(hint_prompt + response)
         results = response or ""


@@ -21,6 +21,8 @@ import networkx as nx
 import trio

 from api.db.services.document_service import DocumentService
+from api.db.services.task_service import has_canceled
+from common.exceptions import TaskCanceledException
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
 from graphrag.entity_resolution import EntityResolution
@@ -106,6 +108,7 @@ async def run_graphrag(
             chat_model,
             embedding_model,
             callback,
+            task_id=row["id"],
         )
         if with_community:
             await graphrag_task_lock.spin_acquire()
@@ -118,6 +121,7 @@
                 chat_model,
                 embedding_model,
                 callback,
+                task_id=row["id"],
             )
     finally:
         graphrag_task_lock.release()
@@ -207,6 +211,10 @@ async def run_graphrag_for_kb(
     failed_docs: list[tuple[str, str]] = []  # (doc_id, error)

     async def build_one(doc_id: str):
+        if has_canceled(row["id"]):
+            callback(msg=f"Task {row['id']} cancelled, stopping execution.")
+            raise TaskCanceledException(f"Task {row['id']} was cancelled")
+
         chunks = all_doc_chunks.get(doc_id, [])
         if not chunks:
             callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
@@ -232,6 +240,7 @@
                 chat_model,
                 embedding_model,
                 callback,
+                task_id=row["id"]
             )
             if sg:
                 subgraphs[doc_id] = sg
@@ -239,14 +248,24 @@
             else:
                 failed_docs.append((doc_id, "subgraph is empty"))
                 callback(msg=f"{msg} empty")
+        except TaskCanceledException as canceled:
+            callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
         except Exception as e:
             failed_docs.append((doc_id, repr(e)))
             callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")

+    if has_canceled(row["id"]):
+        callback(msg=f"Task {row['id']} cancelled before processing documents.")
+        raise TaskCanceledException(f"Task {row['id']} was cancelled")
+
     async with trio.open_nursery() as nursery:
         for doc_id in doc_ids:
             nursery.start_soon(build_one, doc_id)

+    if has_canceled(row["id"]):
+        callback(msg=f"Task {row['id']} cancelled after document processing.")
+        raise TaskCanceledException(f"Task {row['id']} was cancelled")
+
     ok_docs = [d for d in doc_ids if d in subgraphs]
     if not ok_docs:
         callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
@@ -257,6 +276,10 @@
     await kb_lock.spin_acquire()
     callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
+
+    if has_canceled(row["id"]):
+        callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
+        raise TaskCanceledException(f"Task {row['id']} was cancelled")
+
     try:
         union_nodes: set = set()
         final_graph = None
@@ -288,6 +311,10 @@
         callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
         return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

+    if has_canceled(row["id"]):
+        callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
+        raise TaskCanceledException(f"Task {row['id']} was cancelled")
+
     await kb_lock.spin_acquire()
     callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
@@ -306,6 +333,7 @@
                 chat_model,
                 embedding_model,
                 callback,
+                task_id=row["id"],
             )

         if with_community:
@@ -317,6 +345,7 @@
                 chat_model,
                 embedding_model,
                 callback,
+                task_id=row["id"],
             )
     finally:
         kb_lock.release()
@@ -343,7 +372,12 @@
     llm_bdl,
     embed_bdl,
     callback,
+    task_id: str = "",
 ):
+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     contains = await does_graph_contains(tenant_id, kb_id, doc_id)
     if contains:
         callback(msg=f"Graph already contains {doc_id}")
@@ -354,15 +388,24 @@
         language=language,
         entity_types=entity_types,
     )
-    ents, rels = await ext(doc_id, chunks, callback)
+    ents, rels = await ext(doc_id, chunks, callback, task_id=task_id)
     subgraph = nx.Graph()
     for ent in ents:
+        if task_id and has_canceled(task_id):
+            callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.")
+            raise TaskCanceledException(f"Task {task_id} was cancelled")
+
         assert "description" in ent, f"entity {ent} does not have description"
         ent["source_id"] = [doc_id]
         subgraph.add_node(ent["entity_name"], **ent)

     ignored_rels = 0
     for rel in rels:
+        if task_id and has_canceled(task_id):
+            callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.")
+            raise TaskCanceledException(f"Task {task_id} was cancelled")
+
         assert "description" in rel, f"relation {rel} does not have description"
         if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
             ignored_rels += 1
@@ -434,17 +477,27 @@
     llm_bdl,
     embed_bdl,
     callback,
+    task_id: str = "",
 ):
+    # Check if task has been canceled before resolution
+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled during entity resolution.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     start = trio.current_time()
     er = EntityResolution(
         llm_bdl,
     )
-    reso = await er(graph, subgraph_nodes, callback=callback)
+    reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
     graph = reso.graph
     change = reso.change
     callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
     callback(msg="Graph resolution updated pagerank.")

+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled after entity resolution.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
     now = trio.current_time()
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@@ -459,12 +512,22 @@
     llm_bdl,
     embed_bdl,
     callback,
+    task_id: str = "",
 ):
+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled before community extraction.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     start = trio.current_time()
     ext = CommunityReportsExtractor(
         llm_bdl,
     )
-    cr = await ext(graph, callback=callback)
+    cr = await ext(graph, callback=callback, task_id=task_id)
+
+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled during community extraction.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     community_structure = cr.structured_output
     community_reports = cr.output
     doc_ids = graph.graph["source_id"]
@@ -472,6 +535,10 @@
     now = trio.current_time()
     callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
     start = now
+
+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled during community indexing.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     chunks = []
     for stru, rep in zip(community_structure, community_reports):
         obj = {
@@ -509,6 +576,10 @@
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)

+    if task_id and has_canceled(task_id):
+        callback(msg=f"Task {task_id} cancelled after community indexing.")
+        raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     now = trio.current_time()
     callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
     return community_structure, community_reports


@@ -71,7 +71,7 @@ class GraphExtractor(Extractor):
         self._left_token_count = llm_invoker.max_length - num_tokens_from_string(self._entity_extract_prompt.format(**self._context_base, input_text=""))
         self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)

-    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
+    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
         token_count = 0
         chunk_key = chunk_key_dp[0]
         content = chunk_key_dp[1]
@@ -86,13 +86,13 @@
         if self.callback:
             self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
         async with chat_limiter:
-            final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf)
+            final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
         token_count += num_tokens_from_string(hint_prompt + final_result)

         history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
         for now_glean_index in range(self._max_gleanings):
             async with chat_limiter:
                 # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
-                glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf)
+                glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
             history.extend([{"role": "assistant", "content": glean_result}])
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
             final_result += glean_result
@@ -101,7 +101,7 @@
             history.extend([{"role": "user", "content": self._if_loop_prompt}])
             async with chat_limiter:
-                if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf)
+                if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":


@@ -41,7 +41,7 @@ class Pipeline(Graph):
         self._doc_id = None

     def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
-        from rag.svr.task_executor import TaskCanceledException
+        from common.exceptions import TaskCanceledException

         log_key = f"{self._flow_id}-{self.task_id}-logs"
         timestamp = timer()
         if has_canceled(self.task_id):


@@ -65,6 +65,7 @@ from common.token_utils import num_tokens_from_string, truncate
 from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from graphrag.utils import chat_limiter
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common.exceptions import TaskCanceledException
 from common import settings
 from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
@@ -127,9 +128,7 @@ def signal_handler(sig, frame):
     sys.exit(0)

-class TaskCanceledException(Exception):
-    def __init__(self, msg):
-        self.msg = msg

 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):