mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: GraphRAG handle cancel gracefully (#11061)
### What problem does this PR solve? GraghRAG handle cancel gracefully. #10997. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -571,7 +571,7 @@ def trace_graphrag():
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
18
common/exceptions.py
Normal file
18
common/exceptions.py
Normal file
@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@ -29,6 +29,8 @@ import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@ -67,7 +69,8 @@ class EntityResolution(Extractor):
|
||||
async def __call__(self, graph: nx.Graph,
|
||||
subgraph_nodes: set[str],
|
||||
prompt_variables: dict[str, Any] | None = None,
|
||||
callback: Callable | None = None) -> EntityResolutionResult:
|
||||
callback: Callable | None = None,
|
||||
task_id: str = "") -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
@ -109,7 +112,7 @@ class EntityResolution(Extractor):
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
await self._resolve_candidate(candidate_batch, result_set, result_lock)
|
||||
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
|
||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
||||
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
|
||||
if cancel_scope.cancelled_caught:
|
||||
@ -136,7 +139,7 @@ class EntityResolution(Extractor):
|
||||
|
||||
async def limited_merge_nodes(graph, nodes, change):
|
||||
async with semaphore:
|
||||
await self._merge_graph_nodes(graph, nodes, change)
|
||||
await self._merge_graph_nodes(graph, nodes, change, task_id)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
@ -153,7 +156,12 @@ class EntityResolution(Extractor):
|
||||
change=change,
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
@ -173,7 +181,7 @@ class EntityResolution(Extractor):
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
|
||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||
return
|
||||
|
||||
@ -14,6 +14,8 @@ from dataclasses import dataclass
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.connection_utils import timeout
|
||||
from graphrag.general import leiden
|
||||
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
@ -51,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
@ -64,6 +66,11 @@ class CommunityReportsExtractor(Extractor):
|
||||
@timeout(120)
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during community report extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
cm_id, cm = community
|
||||
weight = cm["weight"]
|
||||
ents = cm["nodes"]
|
||||
@ -95,7 +102,10 @@ class CommunityReportsExtractor(Extractor):
|
||||
async with chat_limiter:
|
||||
try:
|
||||
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {})
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before LLM call.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
@ -136,6 +146,9 @@ class CommunityReportsExtractor(Extractor):
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for community in comm.items():
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
nursery.start_soon(extract_community_report, community)
|
||||
if callback:
|
||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
||||
|
||||
@ -23,7 +23,9 @@ from typing import Callable
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.connection_utils import timeout
|
||||
from common.token_utils import truncate
|
||||
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from graphrag.utils import (
|
||||
GraphChange,
|
||||
@ -38,7 +40,7 @@ from graphrag.utils import (
|
||||
)
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts.generator import message_fit_in
|
||||
from common.token_utils import truncate
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
|
||||
@ -60,7 +62,7 @@ class Extractor:
|
||||
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
|
||||
|
||||
@timeout(60 * 20)
|
||||
def _chat(self, system, history, gen_conf={}):
|
||||
def _chat(self, system, history, gen_conf={}, task_id=""):
|
||||
hist = deepcopy(history)
|
||||
conf = deepcopy(gen_conf)
|
||||
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
|
||||
@ -69,6 +71,12 @@ class Extractor:
|
||||
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
|
||||
response = ""
|
||||
for attempt in range(3):
|
||||
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
try:
|
||||
response = self._llm.chat(system_msg[0]["content"], hist, conf)
|
||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||
@ -99,25 +107,29 @@ class Extractor:
|
||||
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None):
|
||||
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
|
||||
self.callback = callback
|
||||
start_ts = trio.current_time()
|
||||
|
||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK):
|
||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
||||
out_results = []
|
||||
error_count = 0
|
||||
max_errors = 3
|
||||
|
||||
limiter = trio.Semaphore(max_concurrency)
|
||||
|
||||
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int):
|
||||
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
|
||||
nonlocal error_count
|
||||
async with limiter:
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction")
|
||||
|
||||
try:
|
||||
await self._process_single_content(chunk_key_dp, idx, total, out_results)
|
||||
await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id)
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
error_msg = f"Error processing chunk {idx+1}/{total}: {str(e)}"
|
||||
error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}"
|
||||
logging.warning(error_msg)
|
||||
if self.callback:
|
||||
self.callback(msg=error_msg)
|
||||
@ -127,7 +139,7 @@ class Extractor:
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, ck in enumerate(chunks):
|
||||
nursery.start_soon(worker, (doc_id, ck), i, len(chunks))
|
||||
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
|
||||
|
||||
if error_count > 0:
|
||||
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
|
||||
@ -137,7 +149,13 @@ class Extractor:
|
||||
|
||||
return out_results
|
||||
|
||||
out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK)
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled before entity extraction")
|
||||
|
||||
out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=task_id)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled after entity extraction")
|
||||
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
@ -154,9 +172,17 @@ class Extractor:
|
||||
start_ts = now
|
||||
logging.info("Entities merging...")
|
||||
all_entities_data = []
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for en_nm, ents in maybe_nodes.items():
|
||||
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
|
||||
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
|
||||
|
||||
now = trio.current_time()
|
||||
if self.callback:
|
||||
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
|
||||
@ -164,9 +190,17 @@ class Extractor:
|
||||
start_ts = now
|
||||
logging.info("Relationships merging...")
|
||||
all_relationships_data = []
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
|
||||
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
|
||||
|
||||
now = trio.current_time()
|
||||
if self.callback:
|
||||
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
|
||||
@ -181,7 +215,10 @@ class Extractor:
|
||||
|
||||
return all_entities_data, all_relationships_data
|
||||
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""):
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes")
|
||||
|
||||
if not entities:
|
||||
return
|
||||
entity_type = sorted(
|
||||
@ -191,7 +228,7 @@ class Extractor:
|
||||
)[0][0]
|
||||
description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities])))
|
||||
already_source_ids = flat_uniq_list(entities, "source_id")
|
||||
description = await self._handle_entity_relation_summary(entity_name, description)
|
||||
description = await self._handle_entity_relation_summary(entity_name, description, task_id=task_id)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
@ -200,18 +237,21 @@ class Extractor:
|
||||
node_data["entity_name"] = entity_name
|
||||
all_relationships_data.append(node_data)
|
||||
|
||||
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None):
|
||||
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""):
|
||||
if not edges_data:
|
||||
return
|
||||
weight = sum([edge["weight"] for edge in edges_data])
|
||||
description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
|
||||
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
|
||||
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id)
|
||||
keywords = flat_uniq_list(edges_data, "keywords")
|
||||
source_id = flat_uniq_list(edges_data, "source_id")
|
||||
edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id)
|
||||
all_relationships_data.append(edge_data)
|
||||
|
||||
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
|
||||
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""):
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes")
|
||||
|
||||
if len(nodes) <= 1:
|
||||
return
|
||||
change.added_updated_nodes.add(nodes[0])
|
||||
@ -220,6 +260,9 @@ class Extractor:
|
||||
node0_attrs = graph.nodes[nodes[0]]
|
||||
node0_neighbors = set(graph.neighbors(nodes[0]))
|
||||
for node1 in nodes[1:]:
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes")
|
||||
|
||||
# Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
|
||||
node1_attrs = graph.nodes[node1]
|
||||
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
|
||||
@ -236,15 +279,18 @@ class Extractor:
|
||||
edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
|
||||
for attr in ["keywords", "source_id"]:
|
||||
edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
|
||||
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
|
||||
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"], task_id=task_id)
|
||||
graph.add_edge(nodes[0], neighbor, **edge0_attrs)
|
||||
else:
|
||||
graph.add_edge(nodes[0], neighbor, **edge1_attrs)
|
||||
graph.remove_node(node1)
|
||||
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
|
||||
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id)
|
||||
graph.nodes[nodes[0]].update(node0_attrs)
|
||||
|
||||
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str) -> str:
|
||||
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str:
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
summary_max_tokens = 512
|
||||
use_description = truncate(description, summary_max_tokens)
|
||||
description_list = use_description.split(GRAPH_FIELD_SEP)
|
||||
@ -258,6 +304,10 @@ class Extractor:
|
||||
)
|
||||
use_prompt = prompt_template.format(**context_base)
|
||||
logging.info(f"Trigger summary: {entity_or_relation_name}")
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
async with chat_limiter:
|
||||
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}])
|
||||
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
return summary
|
||||
|
||||
@ -97,7 +97,7 @@ class GraphExtractor(Extractor):
|
||||
self._entity_types_key: ",".join(entity_types),
|
||||
}
|
||||
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
|
||||
token_count = 0
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
@ -107,7 +107,7 @@ class GraphExtractor(Extractor):
|
||||
}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], {}))
|
||||
response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
|
||||
@ -21,6 +21,8 @@ import networkx as nx
|
||||
import trio
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.misc_utils import get_uuid
|
||||
from common.connection_utils import timeout
|
||||
from graphrag.entity_resolution import EntityResolution
|
||||
@ -106,6 +108,7 @@ async def run_graphrag(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
if with_community:
|
||||
await graphrag_task_lock.spin_acquire()
|
||||
@ -118,6 +121,7 @@ async def run_graphrag(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
graphrag_task_lock.release()
|
||||
@ -207,6 +211,10 @@ async def run_graphrag_for_kb(
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
@ -232,6 +240,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"]
|
||||
)
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
@ -239,14 +248,24 @@ async def run_graphrag_for_kb(
|
||||
else:
|
||||
failed_docs.append((doc_id, "subgraph is empty"))
|
||||
callback(msg=f"{msg} empty")
|
||||
except TaskCanceledException as canceled:
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
|
||||
except Exception as e:
|
||||
failed_docs.append((doc_id, repr(e)))
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for doc_id in doc_ids:
|
||||
nursery.start_soon(build_one, doc_id)
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
@ -257,6 +276,10 @@ async def run_graphrag_for_kb(
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
try:
|
||||
union_nodes: set = set()
|
||||
final_graph = None
|
||||
@ -288,6 +311,10 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||
|
||||
@ -306,6 +333,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
|
||||
if with_community:
|
||||
@ -317,6 +345,7 @@ async def run_graphrag_for_kb(
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
kb_lock.release()
|
||||
@ -343,7 +372,12 @@ async def generate_subgraph(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
|
||||
if contains:
|
||||
callback(msg=f"Graph already contains {doc_id}")
|
||||
@ -354,15 +388,24 @@ async def generate_subgraph(
|
||||
language=language,
|
||||
entity_types=entity_types,
|
||||
)
|
||||
ents, rels = await ext(doc_id, chunks, callback)
|
||||
ents, rels = await ext(doc_id, chunks, callback, task_id=task_id)
|
||||
subgraph = nx.Graph()
|
||||
|
||||
for ent in ents:
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
assert "description" in ent, f"entity {ent} does not have description"
|
||||
ent["source_id"] = [doc_id]
|
||||
subgraph.add_node(ent["entity_name"], **ent)
|
||||
|
||||
ignored_rels = 0
|
||||
for rel in rels:
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
assert "description" in rel, f"relation {rel} does not have description"
|
||||
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
|
||||
ignored_rels += 1
|
||||
@ -434,17 +477,27 @@ async def resolve_entities(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
# Check if task has been canceled before resolution
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during entity resolution.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
|
||||
graph = reso.graph
|
||||
change = reso.change
|
||||
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
|
||||
callback(msg="Graph resolution updated pagerank.")
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled after entity resolution.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
@ -459,12 +512,22 @@ async def extract_community(
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
task_id: str = "",
|
||||
):
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled before community extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
)
|
||||
cr = await ext(graph, callback=callback)
|
||||
cr = await ext(graph, callback=callback, task_id=task_id)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during community extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
community_structure = cr.structured_output
|
||||
community_reports = cr.output
|
||||
doc_ids = graph.graph["source_id"]
|
||||
@ -472,6 +535,10 @@ async def extract_community(
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
start = now
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled during community indexing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
chunks = []
|
||||
for stru, rep in zip(community_structure, community_reports):
|
||||
obj = {
|
||||
@ -509,6 +576,10 @@ async def extract_community(
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
callback(msg=f"Task {task_id} cancelled after community indexing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
return community_structure, community_reports
|
||||
|
||||
@ -71,7 +71,7 @@ class GraphExtractor(Extractor):
|
||||
self._left_token_count = llm_invoker.max_length - num_tokens_from_string(self._entity_extract_prompt.format(**self._context_base, input_text=""))
|
||||
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
|
||||
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
|
||||
token_count = 0
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
@ -86,13 +86,13 @@ class GraphExtractor(Extractor):
|
||||
if self.callback:
|
||||
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
||||
async with chat_limiter:
|
||||
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf)
|
||||
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
async with chat_limiter:
|
||||
# glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
|
||||
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf)
|
||||
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
|
||||
history.extend([{"role": "assistant", "content": glean_result}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
@ -101,7 +101,7 @@ class GraphExtractor(Extractor):
|
||||
|
||||
history.extend([{"role": "user", "content": self._if_loop_prompt}])
|
||||
async with chat_limiter:
|
||||
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf)
|
||||
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
|
||||
@ -41,7 +41,7 @@ class Pipeline(Graph):
|
||||
self._doc_id = None
|
||||
|
||||
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
|
||||
from rag.svr.task_executor import TaskCanceledException
|
||||
from common.exceptions import TaskCanceledException
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
timestamp = timer()
|
||||
if has_canceled(self.task_id):
|
||||
|
||||
@ -65,6 +65,7 @@ from common.token_utils import num_tokens_from_string, truncate
|
||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||
from graphrag.utils import chat_limiter
|
||||
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
|
||||
|
||||
@ -127,9 +128,7 @@ def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
|
||||
|
||||
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
||||
|
||||
Reference in New Issue
Block a user