From 0782a7d3c6a1fadcfb5f23edd4098af54798403e Mon Sep 17 00:00:00 2001
From: Stephen Hu <812791840@qq.com>
Date: Mon, 26 Jan 2026 11:34:54 +0800
Subject: [PATCH] Refactor: improve task cancellation checks in RAPTOR (#12813)

### What problem does this PR solve?

Introduced a helper method _check_task_canceled to centralize and simplify
task cancellation checks throughout
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval. This reduces code
duplication and improves maintainability.

### Type of change

- [x] Refactoring
---
 rag/raptor.py | 33 ++++++++++++---------------------
 1 file changed, 12 insertions(+), 21 deletions(-)

diff --git a/rag/raptor.py b/rag/raptor.py
index 867911d22..ce9e5345e 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -54,6 +54,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._max_token = max_token
         self._max_errors = max(1, max_errors)
         self._error_count = 0
+
+    def _check_task_canceled(self, task_id: str, message: str = ""):
+        if task_id and has_canceled(task_id):
+            log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
+            logging.info(log_msg)
+            raise TaskCanceledException(f"Task {task_id} was cancelled")
 
     @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
@@ -95,10 +101,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during get optimal clusters.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "get optimal clusters")
 
             gm = GaussianMixture(n_components=n, random_state=random_state)
             gm.fit(embeddings)
@@ -117,19 +120,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during RAPTOR summarization.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "summarization")
 
             texts = [chunks[i][0] for i in ck_idx]
             len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
             cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
             try:
                 async with chat_limiter:
-                    if task_id and has_canceled(task_id):
-                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
-                        raise TaskCanceledException(f"Task {task_id} was cancelled")
+                    self._check_task_canceled(task_id, "before LLM call")
 
                     cnt = await self._chat(
                         "You're a helpful assistant.",
@@ -148,9 +146,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     )
 
                 logging.debug(f"SUM: {cnt}")
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                self._check_task_canceled(task_id, "before embedding")
 
                 embds = await self._embedding_encode(cnt)
                 chunks.append((cnt, embds))
@@ -167,10 +163,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
         labels = []
         while end - start > 1:
-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "layer processing")
 
             embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
@@ -203,9 +196,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             for c in range(n_clusters):
                 ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                 assert len(ck_idx) > 0
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                self._check_task_canceled(task_id, "before cluster processing")
                 tasks.append(asyncio.create_task(summarize(ck_idx)))
             try:
                 await asyncio.gather(*tasks, return_exceptions=False)
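
For readers outside the RAGFlow codebase, the sketch below illustrates the pattern the patch applies: one helper that logs and raises at every cancellation checkpoint, instead of a repeated `if task_id and has_canceled(...)` block. It is a minimal, self-contained sketch, not the project's code: `has_canceled`, `TaskCanceledException`, and `TreeBuilderSketch` are local stand-ins defined here for illustration; the real `has_canceled` and `TaskCanceledException` live elsewhere in RAGFlow and are only referenced, never defined, in this diff.

```python
# Minimal sketch of the centralized cancellation-check pattern (stand-ins only).
import logging

# Stand-in for whatever backend state the real has_canceled() consults.
_CANCELED_TASKS: set[str] = set()


class TaskCanceledException(Exception):
    """Stand-in for the exception type raised by the real pipeline."""


def has_canceled(task_id: str) -> bool:
    """Stand-in: the real implementation checks the task's state in storage."""
    return task_id in _CANCELED_TASKS


class TreeBuilderSketch:
    def _check_task_canceled(self, task_id: str, message: str = ""):
        # Single place that logs and raises, mirroring the helper in the patch.
        if task_id and has_canceled(task_id):
            logging.info(f"Task {task_id} cancelled during RAPTOR {message}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    def build(self, task_id: str, items: list[str]) -> list[str]:
        processed = []
        for item in items:
            # One call per checkpoint replaces the nested if/has_canceled block.
            self._check_task_canceled(task_id, "layer processing")
            processed.append(item.upper())
        return processed


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    builder = TreeBuilderSketch()
    print(builder.build("t1", ["a", "b"]))   # completes: ['A', 'B']
    _CANCELED_TASKS.add("t2")
    try:
        builder.build("t2", ["a", "b"])
    except TaskCanceledException as e:
        print(e)                             # Task t2 was cancelled
```

The design choice mirrors the PR description: because the helper both logs and raises, every call site collapses to one line, and the wording of the log message ("during RAPTOR {message}") stays consistent across checkpoints.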