diff --git a/api/db/__init__.py b/api/db/__init__.py
index c13c85de6..0ebd9f56f 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -70,4 +70,7 @@ class PipelineTaskType(StrEnum):
 
 VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
 
+PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES = {PipelineTaskType.RAPTOR.lower(), PipelineTaskType.GRAPH_RAG.lower(), PipelineTaskType.MINDMAP.lower()}
+
+
 KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 198028210..a64ae16de 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -27,7 +27,7 @@ import xxhash
 from peewee import fn, Case, JOIN
 
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
-from api.db import FileType, UserTenantRole, CanvasCategory
+from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
 from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
     User
 from api.db.db_utils import bulk_insert_into_db
@@ -372,12 +372,16 @@ class DocumentService(CommonService):
     def get_unfinished_docs(cls):
         fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
                   cls.model.run, cls.model.parser_id]
+        unfinished_task_query = Task.select(Task.doc_id).where(
+            (Task.progress >= 0) & (Task.progress < 1)
+        )
+
         docs = cls.model.select(*fields) \
             .where(
             cls.model.status == StatusEnum.VALID.value,
             ~(cls.model.type == FileType.VIRTUAL.value),
-            cls.model.progress < 1,
-            cls.model.progress > 0)
+            (((cls.model.progress < 1) & (cls.model.progress > 0)) |
+             (cls.model.id.in_(unfinished_task_query))))  # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
         return list(docs.dicts())
 
     @classmethod
@@ -619,13 +623,17 @@ class DocumentService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def begin2parse(cls, docid):
-        cls.update_by_id(
-            docid, {"progress": random.random() * 1 / 100.,
-                    "progress_msg": "Task is queued...",
-                    "process_begin_at": get_format_time(),
-                    "run": TaskStatus.RUNNING.value
-                    })
+    def begin2parse(cls, doc_id, keep_progress=False):
+        info = {
+            "progress_msg": "Task is queued...",
+            "process_begin_at": get_format_time(),
+        }
+        if not keep_progress:
+            info["progress"] = random.random() * 1 / 100.
+ info["run"] = TaskStatus.RUNNING.value + # keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks + + cls.update_by_id(doc_id, info) @classmethod @DB.connection_context() @@ -684,8 +692,13 @@ class DocumentService(CommonService): bad = 0 e, doc = DocumentService.get_by_id(d["id"]) status = doc.run # TaskStatus.RUNNING.value + doc_progress = doc.progress if doc and doc.progress else 0.0 + special_task_running = False priority = 0 for t in tsks: + task_type = (t.task_type or "").lower() + if task_type in PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES: + special_task_running = True if 0 <= t.progress < 1: finished = False if t.progress == -1: @@ -702,13 +715,15 @@ class DocumentService(CommonService): prg = 1 status = TaskStatus.DONE.value + # only for special task and parsed docs and unfinised + freeze_progress = special_task_running and doc_progress >= 1 and not finished msg = "\n".join(sorted(msg)) info = { "process_duration": datetime.timestamp( datetime.now()) - d["process_begin_at"].timestamp(), "run": status} - if prg != 0: + if prg != 0 and not freeze_progress: info["progress"] = prg if msg: info["progress_msg"] = msg @@ -858,7 +873,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d task["doc_id"] = fake_doc_id task["doc_ids"] = doc_ids - DocumentService.begin2parse(sample_doc_id["id"]) + DocumentService.begin2parse(sample_doc_id["id"], keep_progress=True) assert REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." return task["id"] @@ -1012,4 +1027,3 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) return [d["id"] for d, _ in files] - diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 9b309ee68..af8dfc186 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -24,6 +24,7 @@ import time import json_repair +from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from common.connection_utils import timeout @@ -192,7 +193,7 @@ async def collect(): canceled = False if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]: task = msg - if task["task_type"] in ["graphrag", "raptor", "mindmap"]: + if task["task_type"] in PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES: task = TaskService.get_task(msg["id"], msg["doc_ids"]) if task: task["doc_id"] = msg["doc_id"]