Refactor: improve task cancellation checks in RAPTOR (#12813)

### What problem does this PR solve?
Introduced a helper method _check_task_canceled to centralize and
simplify task cancellation checks throughout
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval. This reduces code
duplication and improves maintainability.

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2026-01-26 11:34:54 +08:00
committed by GitHub
parent 4236a62855
commit 0782a7d3c6

View File

@ -54,6 +54,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
def _check_task_canceled(self, task_id: str, message: str = ""):
if task_id and has_canceled(task_id):
log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
logging.info(log_msg)
raise TaskCanceledException(f"Task {task_id} was cancelled")
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
@ -95,10 +101,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "get optimal clusters")
gm = GaussianMixture(n_components=n, random_state=random_state)
gm.fit(embeddings)
@ -117,19 +120,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
async def summarize(ck_idx: list[int]):
nonlocal chunks
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR summarization.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "summarization")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "before LLM call")
cnt = await self._chat(
"You're a helpful assistant.",
@ -148,9 +146,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "before embedding")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
@ -167,10 +163,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
labels = []
while end - start > 1:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "layer processing")
embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2:
@ -203,9 +196,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
self._check_task_canceled(task_id, "before cluster processing")
tasks.append(asyncio.create_task(summarize(ck_idx)))
try:
await asyncio.gather(*tasks, return_exceptions=False)