Perf: make do_cancel quicker. (#8846)

### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
This commit is contained in:
Kevin Hu
2025-07-15 14:35:00 +08:00
committed by GitHub
parent 5fa6f2f151
commit 24c41d2a61
3 changed files with 36 additions and 29 deletions

View File

@ -606,6 +606,7 @@ TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout( def timeout(
seconds: float |int = None, seconds: float |int = None,
attempts: int = 2,
*, *,
exception: Optional[TimeoutException] = None, exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None on_timeout: Optional[OnTimeoutCallback] = None
@ -625,41 +626,46 @@ def timeout(
thread.daemon = True thread.daemon = True
thread.start() thread.start()
try: for a in range(attempts):
result = result_queue.get(timeout=seconds) try:
if isinstance(result, Exception): result = result_queue.get(timeout=seconds)
raise result if isinstance(result, Exception):
return result raise result
except queue.Empty: return result
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds") except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs) -> Any: async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None: if seconds is None:
return await func(*args, **kwargs) return await func(*args, **kwargs)
try: for a in range(attempts):
with trio.fail_after(seconds): try:
return await func(*args, **kwargs) with trio.fail_after(seconds):
except trio.TooSlowError: return await func(*args, **kwargs)
if on_timeout is not None: except trio.TooSlowError:
if callable(on_timeout): if a < attempts -1:
result = on_timeout() continue
if isinstance(result, Coroutine): if on_timeout is not None:
return await result if callable(on_timeout):
return result result = on_timeout()
return on_timeout if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None: if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds") raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException): if isinstance(exception, BaseException):
raise exception raise exception
if isinstance(exception, type) and issubclass(exception, BaseException): if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds") raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided") raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
return async_wrapper return async_wrapper

View File

@ -42,6 +42,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._prompt = prompt self._prompt = prompt
self._max_token = max_token self._max_token = max_token
@timeout(60)
async def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
if response: if response:

View File

@ -214,7 +214,7 @@ async def collect():
canceled = False canceled = False
task = TaskService.get_task(msg["id"]) task = TaskService.get_task(msg["id"])
if task: if task:
canceled = TaskService.do_cancel(task["id"]) canceled = DocumentService.do_cancel(task["doc_id"])
if not task or canceled: if not task or canceled:
state = "is unknown" if not task else "has been cancelled" state = "is unknown" if not task else "has been cancelled"
FAILED_TASKS += 1 FAILED_TASKS += 1
@ -382,7 +382,7 @@ async def build_chunks(task, progress_callback):
docs_to_tag = [] docs_to_tag = []
for d in docs: for d in docs:
task_canceled = TaskService.do_cancel(task["id"]) task_canceled = DocumentService.do_cancel(task["doc_id"])
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return
@ -531,7 +531,7 @@ async def do_handle_task(task):
progress_callback(-1, msg=error_message) progress_callback(-1, msg=error_message)
raise Exception(error_message) raise Exception(error_message)
task_canceled = TaskService.do_cancel(task_id) task_canceled = DocumentService.do_cancel(task_doc_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return
@ -609,7 +609,7 @@ async def do_handle_task(task):
for b in range(0, len(chunks), DOC_BULK_SIZE): for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = TaskService.do_cancel(task_id) task_canceled = DocumentService.do_cancel(task_doc_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return