diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index df95d2517..74bb83554 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -15,6 +15,7 @@ # import json import logging +import os from datetime import datetime, timedelta from peewee import fn @@ -81,9 +82,17 @@ class PipelineOperationLogService(CommonService): cls.model.update_date, ] + @classmethod + def save(cls, **kwargs): + """ + wrap this function in a transaction + """ + sample_obj = cls.model(**kwargs).save(force_insert=True) + return sample_obj + @classmethod @DB.connection_context() - def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl:str="{}"): + def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"): referred_document_id = document_id if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids: @@ -163,7 +172,19 @@ class PipelineOperationLogService(CommonService): log["create_date"] = datetime_format(datetime.now()) log["update_time"] = current_timestamp() log["update_date"] = datetime_format(datetime.now()) - obj = cls.save(**log) + + with DB.atomic(): + obj = cls.save(**log) + + limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1)) + total = cls.model.select().where(cls.model.kb_id == document.kb_id).count() + + if total > limit: + keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)] + + deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute() + logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}") + return obj @classmethod