From 24c41d2a6113a70a9dc25941b46ca2a4d9a050ed Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 15 Jul 2025 14:35:00 +0800 Subject: [PATCH] Perf: make `do_cancel` quicker. (#8846) ### What problem does this PR solve? ### Type of change - [x] Performance Improvement --- api/utils/api_utils.py | 56 ++++++++++++++++++++++------------------ rag/raptor.py | 1 + rag/svr/task_executor.py | 8 +++--- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 546cb2dfb..f574a8f00 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -606,6 +606,7 @@ TimeoutException = Union[Type[BaseException], BaseException] OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] def timeout( seconds: float |int = None, + attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None @@ -625,41 +626,46 @@ def timeout( thread.daemon = True thread.start() - try: - result = result_queue.get(timeout=seconds) - if isinstance(result, Exception): - raise result - return result - except queue.Empty: - raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds") + for a in range(attempts): + try: + result = result_queue.get(timeout=seconds) + if isinstance(result, Exception): + raise result + return result + except queue.Empty: + pass + raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.") @wraps(func) async def async_wrapper(*args, **kwargs) -> Any: if seconds is None: return await func(*args, **kwargs) - try: - with trio.fail_after(seconds): - return await func(*args, **kwargs) - except trio.TooSlowError: - if on_timeout is not None: - if callable(on_timeout): - result = on_timeout() - if isinstance(result, Coroutine): - return await result - return result - return on_timeout + for a in range(attempts): + try: + with trio.fail_after(seconds): + return await func(*args, **kwargs) + except trio.TooSlowError: + if a < attempts -1: + continue + if on_timeout is not None: + if callable(on_timeout): + result = on_timeout() + if isinstance(result, Coroutine): + return await result + return result + return on_timeout - if exception is None: - raise TimeoutError(f"Operation timed out after {seconds} seconds") + if exception is None: + raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - if isinstance(exception, BaseException): - raise exception + if isinstance(exception, BaseException): + raise exception - if isinstance(exception, type) and issubclass(exception, BaseException): - raise exception(f"Operation timed out after {seconds} seconds") + if isinstance(exception, type) and issubclass(exception, BaseException): + raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - raise RuntimeError("Invalid exception type provided") + raise RuntimeError("Invalid exception type provided") if asyncio.iscoroutinefunction(func): return async_wrapper diff --git a/rag/raptor.py b/rag/raptor.py index 20a4334d0..80bf6ca88 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -42,6 +42,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: self._prompt = prompt self._max_token = max_token + @timeout(60) async def _chat(self, system, history, gen_conf): response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) if response: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 7fdc570c2..094bf6299 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -214,7 +214,7 @@ async def collect(): canceled = False task = TaskService.get_task(msg["id"]) if task: - canceled = TaskService.do_cancel(task["id"]) + canceled = DocumentService.do_cancel(task["doc_id"]) if not task or canceled: state = "is unknown" if not task else "has been cancelled" FAILED_TASKS += 1 @@ -382,7 +382,7 @@ async def build_chunks(task, progress_callback): docs_to_tag = [] for d in docs: - task_canceled = TaskService.do_cancel(task["id"]) + task_canceled = DocumentService.do_cancel(task["doc_id"]) if task_canceled: progress_callback(-1, msg="Task has been canceled.") return @@ -531,7 +531,7 @@ async def do_handle_task(task): progress_callback(-1, msg=error_message) raise Exception(error_message) - task_canceled = TaskService.do_cancel(task_id) + task_canceled = DocumentService.do_cancel(task_doc_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") return @@ -609,7 +609,7 @@ async def do_handle_task(task): for b in range(0, len(chunks), DOC_BULK_SIZE): doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) - task_canceled = TaskService.do_cancel(task_id) + task_canceled = DocumentService.do_cancel(task_doc_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") return