Refactor: improve task cancellation checks in RAPTOR (#12813)

### What problem does this PR solve?
Introduced a helper method _check_task_canceled to centralize and
simplify task cancellation checks throughout
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval. This reduces code
duplication and improves maintainability.

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2026-01-26 11:34:54 +08:00
committed by GitHub
parent 4236a62855
commit 0782a7d3c6

View File

@ -54,6 +54,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._max_token = max_token self._max_token = max_token
self._max_errors = max(1, max_errors) self._max_errors = max(1, max_errors)
self._error_count = 0 self._error_count = 0
def _check_task_canceled(self, task_id: str, message: str = ""):
    """Abort the current RAPTOR stage if the owning task was cancelled.

    Checks the cancellation flag for *task_id* and, when set, logs the
    event and raises TaskCanceledException so callers unwind immediately.
    A falsy task_id disables the check entirely.
    """
    # Guard clauses: nothing to do without a task id or an active cancel flag.
    if not task_id:
        return
    if not has_canceled(task_id):
        return
    logging.info(f"Task {task_id} cancelled during RAPTOR {message}.")
    raise TaskCanceledException(f"Task {task_id} was cancelled")
@timeout(60 * 20) @timeout(60 * 20)
async def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
@ -95,10 +101,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters) n_clusters = np.arange(1, max_clusters)
bics = [] bics = []
for n in n_clusters: for n in n_clusters:
if task_id: self._check_task_canceled(task_id, "get optimal clusters")
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
gm = GaussianMixture(n_components=n, random_state=random_state) gm = GaussianMixture(n_components=n, random_state=random_state)
gm.fit(embeddings) gm.fit(embeddings)
@ -117,19 +120,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
async def summarize(ck_idx: list[int]): async def summarize(ck_idx: list[int]):
nonlocal chunks nonlocal chunks
if task_id: self._check_task_canceled(task_id, "summarization")
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR summarization.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx] texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try: try:
async with chat_limiter: async with chat_limiter:
if task_id and has_canceled(task_id): self._check_task_canceled(task_id, "before LLM call")
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cnt = await self._chat( cnt = await self._chat(
"You're a helpful assistant.", "You're a helpful assistant.",
@ -148,9 +146,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
) )
logging.debug(f"SUM: {cnt}") logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id): self._check_task_canceled(task_id, "before embedding")
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embds = await self._embedding_encode(cnt) embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds)) chunks.append((cnt, embds))
@ -167,10 +163,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
labels = [] labels = []
while end - start > 1: while end - start > 1:
if task_id: self._check_task_canceled(task_id, "layer processing")
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embeddings = [embd for _, embd in chunks[start:end]] embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2: if len(embeddings) == 2:
@ -203,9 +196,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0 assert len(ck_idx) > 0
if task_id and has_canceled(task_id): self._check_task_canceled(task_id, "before cluster processing")
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
tasks.append(asyncio.create_task(summarize(ck_idx))) tasks.append(asyncio.create_task(summarize(ck_idx)))
try: try:
await asyncio.gather(*tasks, return_exceptions=False) await asyncio.gather(*tasks, return_exceptions=False)