Refactor: improve task cancellation checks in RAPTOR (#12813)
### What problem does this PR solve?

Introduced a helper method `_check_task_canceled` to centralize and simplify task cancellation checks throughout `RecursiveAbstractiveProcessing4TreeOrganizedRetrieval`. This reduces code duplication and improves maintainability.

### Type of change

- [x] Refactoring
```diff
@@ -55,6 +55,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._max_errors = max(1, max_errors)
         self._error_count = 0

+    def _check_task_canceled(self, task_id: str, message: str = ""):
+        if task_id and has_canceled(task_id):
+            log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
+            logging.info(log_msg)
+            raise TaskCanceledException(f"Task {task_id} was cancelled")
+
     @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
         cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
```
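Each call site now collapses to a single line. Below is a minimal, self-contained sketch of the helper's control flow; the stubbed `has_canceled`, `TaskCanceledException`, and the in-memory cancel store are assumptions standing in for ragflow's real task-service imports:

```python
# Standalone sketch of the centralized check. Only the control flow mirrors
# the diff; the stubs below are NOT ragflow's real task-service API.
import logging


class TaskCanceledException(Exception):
    """Stub for the exception type referenced in the diff."""


_CANCELED: set[str] = set()  # hypothetical in-memory store of cancelled task ids


def has_canceled(task_id: str) -> bool:
    return task_id in _CANCELED


class RaptorSketch:
    def _check_task_canceled(self, task_id: str, message: str = ""):
        # A falsy task_id (None or "") skips the lookup entirely, so callers
        # can pass task_id unconditionally.
        if task_id and has_canceled(task_id):
            logging.info(f"Task {task_id} cancelled during RAPTOR {message}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")


_CANCELED.add("t1")
try:
    RaptorSketch()._check_task_canceled("t1", "demo step")
except TaskCanceledException as e:
    print(e)  # Task t1 was cancelled
```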
```diff
@@ -95,10 +101,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         n_clusters = np.arange(1, max_clusters)
         bics = []
         for n in n_clusters:
-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during get optimal clusters.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "get optimal clusters")

             gm = GaussianMixture(n_components=n, random_state=random_state)
             gm.fit(embeddings)
```
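For context, the loop this hunk touches sweeps candidate cluster counts and records each model's BIC to pick the best one. A minimal sketch of that selection pattern with scikit-learn, independent of the cancellation change (random vectors stand in for the real embeddings):

```python
# BIC-based cluster-count selection, the pattern surrounding this hunk.
import numpy as np
from sklearn.mixture import GaussianMixture

embeddings = np.random.RandomState(0).rand(60, 8)  # stand-in for real embeddings
max_clusters = 6
random_state = 0

n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
    gm = GaussianMixture(n_components=int(n), random_state=random_state)
    gm.fit(embeddings)
    bics.append(gm.bic(embeddings))  # lower BIC = better fit/complexity trade-off

print(int(n_clusters[np.argmin(bics)]))  # chosen cluster count
```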
```diff
@@ -117,19 +120,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         async def summarize(ck_idx: list[int]):
             nonlocal chunks

-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during RAPTOR summarization.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "summarization")

             texts = [chunks[i][0] for i in ck_idx]
             len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
             cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
             try:
                 async with chat_limiter:
-                    if task_id and has_canceled(task_id):
-                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
-                        raise TaskCanceledException(f"Task {task_id} was cancelled")
+                    self._check_task_canceled(task_id, "before LLM call")

                     cnt = await self._chat(
                         "You're a helpful assistant.",
```
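Note the placement of the second call: acquiring `chat_limiter` can block for a while, so cancellation is re-tested immediately before the expensive LLM call. A small sketch of that re-check-after-waiting pattern, with `asyncio.Semaphore` assumed as a stand-in for ragflow's `chat_limiter`:

```python
import asyncio

chat_limiter = asyncio.Semaphore(1)  # stand-in for ragflow's chat_limiter


async def summarize_sketch(check, llm_call):
    check("summarization")  # cheap early exit before any work is done
    async with chat_limiter:
        # Re-check here: the task may have been cancelled while this
        # coroutine was queued waiting for the limiter.
        check("before LLM call")
        return await llm_call()


async def main():
    async def fake_llm():
        await asyncio.sleep(0)
        return "summary"

    print(await summarize_sketch(lambda msg: None, fake_llm))  # summary


asyncio.run(main())
```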
```diff
@@ -148,9 +146,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     )
                 logging.debug(f"SUM: {cnt}")

-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                self._check_task_canceled(task_id, "before embedding")

                 embds = await self._embedding_encode(cnt)
                 chunks.append((cnt, embds))
```
```diff
@@ -167,10 +163,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

         labels = []
         while end - start > 1:
-            if task_id:
-                if has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+            self._check_task_canceled(task_id, "layer processing")

             embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
```
```diff
@@ -203,9 +196,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             for c in range(n_clusters):
                 ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                 assert len(ck_idx) > 0
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                self._check_task_canceled(task_id, "before cluster processing")
                 tasks.append(asyncio.create_task(summarize(ck_idx)))
             try:
                 await asyncio.gather(*tasks, return_exceptions=False)
```
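One behavior worth noting around the unchanged `asyncio.gather(..., return_exceptions=False)`: the first `TaskCanceledException` raised inside any `summarize` task propagates straight out of the `await`, so a single failed check aborts the whole layer (sibling tasks keep running unless cancelled separately). A minimal demonstration, with the exception type stubbed:

```python
import asyncio


class TaskCanceledException(Exception):
    """Stub mirroring the exception used in the diff."""


async def summarize(i: int):
    if i == 1:
        raise TaskCanceledException("Task t1 was cancelled")
    await asyncio.sleep(0.01)
    return i


async def main():
    tasks = [asyncio.create_task(summarize(i)) for i in range(3)]
    try:
        # return_exceptions=False (the default) re-raises the first
        # exception here instead of collecting it in the result list.
        await asyncio.gather(*tasks, return_exceptions=False)
    except TaskCanceledException as e:
        print(f"gather aborted: {e}")


asyncio.run(main())
```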