diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 4693a6684..5fd0eefc3 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import os import random import xxhash @@ -256,36 +257,55 @@ class TaskService(CommonService): @DB.connection_context() def update_progress(cls, id, info): """Update the progress information for a task. - + This method updates both the progress message and completion percentage of a task. It handles platform-specific behavior (macOS vs others) and uses database locking when necessary to ensure thread safety. - + + Update Rules: + - progress_msg: Always appends the new message to the existing one, and trims the result to max 3000 lines. + - progress: Only updates if the current progress is not -1 AND + (the new progress is -1 OR greater than the existing progress), + to avoid overwriting valid progress with invalid or regressive values. + Args: id (str): The unique identifier of the task to update. info (dict): Dictionary containing progress information with keys: - progress_msg (str, optional): Progress message to append - progress (float, optional): Progress percentage (0.0 to 1.0) """ + task = cls.model.get_by_id(id) + if not task: + logging.warning("Update_progress error: task not found") + return + if os.environ.get("MACOS"): if info["progress_msg"]: - task = cls.model.get_by_id(id) progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: - cls.model.update(progress=info["progress"]).where( - cls.model.id == id + prog = info["progress"] + cls.model.update(progress=prog).where( + (cls.model.id == id) & + ( + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) + ) ).execute() return with DB.lock("update_progress", -1): if info["progress_msg"]: - task = cls.model.get_by_id(id) progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: - cls.model.update(progress=info["progress"]).where( - cls.model.id == id + prog = info["progress"] + cls.model.update(progress=prog).where( + (cls.model.id == id) & + ( + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) + ) ).execute()