diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0377c72c3..b24e7c141 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -46,7 +46,7 @@ import faulthandler import numpy as np from peewee import DoesNotExist -from api.db import LLMType, ParserType, TaskStatus +from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService @@ -213,8 +213,7 @@ async def collect(): canceled = False task = TaskService.get_task(msg["id"]) if task: - _, doc = DocumentService.get_by_id(task["doc_id"]) - canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0 + canceled = TaskService.do_cancel(task["id"]) if not task or canceled: state = "is unknown" if not task else "has been cancelled" FAILED_TASKS += 1