mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-27 05:36:33 +08:00
Refactor: improve task cancellation checks in RAPTOR (#12813)
### What problem does this PR solve?

Introduces a helper method, `_check_task_canceled`, that centralizes and simplifies the task-cancellation checks throughout `RecursiveAbstractiveProcessing4TreeOrganizedRetrieval`. This reduces code duplication and improves maintainability.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -54,6 +54,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
self._max_token = max_token
|
||||
self._max_errors = max(1, max_errors)
|
||||
self._error_count = 0
|
||||
|
||||
def _check_task_canceled(self, task_id: str, message: str = ""):
|
||||
if task_id and has_canceled(task_id):
|
||||
log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
|
||||
logging.info(log_msg)
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def _chat(self, system, history, gen_conf):
|
||||
@ -95,10 +101,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
n_clusters = np.arange(1, max_clusters)
|
||||
bics = []
|
||||
for n in n_clusters:
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "get optimal clusters")
|
||||
|
||||
gm = GaussianMixture(n_components=n, random_state=random_state)
|
||||
gm.fit(embeddings)
|
||||
@ -117,19 +120,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
async def summarize(ck_idx: list[int]):
|
||||
nonlocal chunks
|
||||
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during RAPTOR summarization.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "summarization")
|
||||
|
||||
texts = [chunks[i][0] for i in ck_idx]
|
||||
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
|
||||
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||
try:
|
||||
async with chat_limiter:
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "before LLM call")
|
||||
|
||||
cnt = await self._chat(
|
||||
"You're a helpful assistant.",
|
||||
@ -148,9 +146,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
)
|
||||
logging.debug(f"SUM: {cnt}")
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "before embedding")
|
||||
|
||||
embds = await self._embedding_encode(cnt)
|
||||
chunks.append((cnt, embds))
|
||||
@ -167,10 +163,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
labels = []
|
||||
while end - start > 1:
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "layer processing")
|
||||
|
||||
embeddings = [embd for _, embd in chunks[start:end]]
|
||||
if len(embeddings) == 2:
|
||||
@ -203,9 +196,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
for c in range(n_clusters):
|
||||
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
||||
assert len(ck_idx) > 0
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
self._check_task_canceled(task_id, "before cluster processing")
|
||||
tasks.append(asyncio.create_task(summarize(ck_idx)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
Reference in New Issue
Block a user