From 0cd8024c34c9d252e5e387ae74baf3191cb03154 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Thu, 6 Nov 2025 17:18:03 +0800 Subject: [PATCH] Feat: RAPTOR handle cancel gracefully (#11074) ### What problem does this PR solve? RAPTOR now handles task cancellation gracefully: when a `task_id` is supplied, it checks `has_canceled(task_id)` before cluster selection, each layer, summarization, the LLM call, and embedding, and raises `TaskCanceledException` if the task was cancelled. This PR also removes the `huggingface-hub` dependency and the `repos` list from `download_deps.py`. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- download_deps.py | 8 -------- rag/raptor.py | 41 ++++++++++++++++++++++++++++++++++++++--- rag/svr/task_executor.py | 4 +++- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/download_deps.py b/download_deps.py index 0d11d451c..10e08bdda 100644 --- a/download_deps.py +++ b/download_deps.py @@ -4,7 +4,6 @@ # /// script # requires-python = ">=3.10" # dependencies = [ -# "huggingface-hub", # "nltk", # ] # /// @@ -40,13 +39,6 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]: ] -repos = [ - "InfiniFlow/text_concat_xgb_v1.0", - "InfiniFlow/deepdoc", - "InfiniFlow/huqie", -] - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download dependencies with optional China mirror support") parser.add_argument("--china-mirrors", action="store_true", help="Use China-accessible mirrors for downloads") diff --git a/rag/raptor.py b/rag/raptor.py index 22f8ce397..6c7b5f2f5 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -20,7 +20,9 @@ import numpy as np from sklearn.mixture import GaussianMixture import trio +from api.db.services.task_service import has_canceled from common.connection_utils import timeout +from common.exceptions import TaskCanceledException from graphrag.utils import ( get_llm_cache, get_embed_cache, @@ -75,18 +77,24 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds)) return embds - def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int): + def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): 
max_clusters = min(self._max_cluster, len(embeddings)) n_clusters = np.arange(1, max_clusters) bics = [] for n in n_clusters: + + if task_id: + if has_canceled(task_id): + logging.info(f"Task {task_id} cancelled during get optimal clusters.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + gm = GaussianMixture(n_components=n, random_state=random_state) gm.fit(embeddings) bics.append(gm.bic(embeddings)) optimal_clusters = n_clusters[np.argmin(bics)] return optimal_clusters - async def __call__(self, chunks, random_state, callback=None): + async def __call__(self, chunks, random_state, callback=None, task_id: str = ""): if len(chunks) <= 1: return [] chunks = [(s, a) for s, a in chunks if s and len(a) > 0] @@ -96,6 +104,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(60*20) async def summarize(ck_idx: list[int]): nonlocal chunks + + if task_id: + if has_canceled(task_id): + logging.info(f"Task {task_id} cancelled during RAPTOR summarization.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + texts = [chunks[i][0] for i in ck_idx] len_per_chunk = int( (self._llm_model.max_length - self._max_token) / len(texts) @@ -104,6 +118,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: [truncate(t, max(1, len_per_chunk)) for t in texts] ) async with chat_limiter: + + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + cnt = await self._chat( "You're a helpful assistant.", [ @@ -122,11 +141,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: cnt, ) logging.debug(f"SUM: {cnt}") + + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + embds = await self._embedding_encode(cnt) chunks.append((cnt, embds)) labels = [] while end - start > 1: + + if task_id: + if 
has_canceled(task_id): + logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + embeddings = [embd for _, embd in chunks[start:end]] if len(embeddings) == 2: await summarize([start, start + 1]) @@ -148,7 +178,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: n_components=min(12, len(embeddings) - 2), metric="cosine", ).fit_transform(embeddings) - n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state) + n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id) if n_clusters == 1: lbls = [0 for _ in range(len(reduced_embeddings))] else: @@ -162,6 +192,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: for c in range(n_clusters): ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] assert len(ck_idx) > 0 + + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + nursery.start_soon(summarize, ck_idx) assert len(chunks) - end == n_clusters, "{} vs. 
{}".format( diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index d07a44ea4..9b3a68ce0 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -659,7 +659,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si raptor_config["threshold"], ) original_length = len(chunks) - chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback) + chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) doc = { "doc_id": fake_doc_id, "kb_id": [str(row["kb_id"])], @@ -814,6 +814,8 @@ async def do_handle_task(task): callback=progress_callback, doc_ids=task.get("doc_ids", []), ) + if fake_doc_ids := task.get("doc_ids", []): + task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes # Either using graphrag or Standard chunking methods elif task_type == "graphrag": ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)