From d907e798931301b9aaad046152dda1cf259afd8e Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Thu, 25 Sep 2025 13:52:50 +0800
Subject: [PATCH] Refa: fake doc ID. (#10276)

### What problem does this PR solve?

#10273

### Type of change

- [x] Refactoring
---
 api/apps/kb_app.py                      |  4 +-
 api/db/services/document_service.py     |  4 +-
 .../pipeline_operation_log_service.py   |  9 +++--
 api/db/services/task_service.py         |  1 +
 api/utils/api_utils.py                  |  4 +-
 rag/svr/task_executor.py                | 37 ++++++++++++++++---
 6 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index f644c9001..128eb9cfe 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -24,7 +24,7 @@ from api.db.services.document_service import DocumentService, queue_raptor_o_gra
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
-from api.db.services.task_service import TaskService
+from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
 from api.utils import get_uuid
@@ -558,7 +558,7 @@ def run_graphrag():
     if not sample_document:
         sample_document = document.to_dict()
 
-    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
     if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
         logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
 
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index a8aadd6ac..48828cb6d 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -121,7 +121,7 @@ class DocumentService(CommonService):
                        orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(File, on=(File.id == File2Document.file_id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
@@ -130,7 +130,7 @@ class DocumentService(CommonService):
                     (fn.LOWER(cls.model.name).contains(keywords.lower()))
                 )
         else:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
                 .join(File, on=(File.id == File2Document.file_id))\
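
Why the `.alias("pipeline_name")` in the two hunks above matters: peewee keys a joined column by its own name, so without the alias the canvas title would surface as `title` in each row dict. A minimal, runnable sketch of that behavior, using stand-in models (`Canvas`/`Doc` are illustrative, not RAGFlow's real schema):

```python
# Stand-in models (illustrative, not RAGFlow's schema) showing how .alias()
# names the joined column in result dicts.
from peewee import JOIN, CharField, ForeignKeyField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")

class Canvas(Model):
    title = CharField()

    class Meta:
        database = db

class Doc(Model):
    name = CharField()
    pipeline = ForeignKeyField(Canvas, null=True)

    class Meta:
        database = db

db.create_tables([Canvas, Doc])
Doc.create(name="a.pdf", pipeline=Canvas.create(title="my-pipeline"))

query = (Doc
         .select(Doc.name, Canvas.title.alias("pipeline_name"))
         .join(Canvas, JOIN.LEFT_OUTER, on=(Doc.pipeline == Canvas.id)))
print(list(query.dicts()))  # [{'name': 'a.pdf', 'pipeline_name': 'my-pipeline'}]
```
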
diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py
index 1d71e41cd..349538c0c 100644
--- a/api/db/services/pipeline_operation_log_service.py
+++ b/api/db/services/pipeline_operation_log_service.py
@@ -25,6 +25,7 @@
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
 from api.utils import current_timestamp, datetime_format, get_uuid
 
@@ -88,7 +89,7 @@ class PipelineOperationLogService(CommonService):
             dsl = ""
 
         referred_document_id = document_id
-        if referred_document_id == "x" and fake_document_ids:
+        if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
@@ -128,7 +129,7 @@ class PipelineOperationLogService(CommonService):
 
         log = dict(
             id=get_uuid(),
-            document_id=document_id,  # "x" or real document_id
+            document_id=document_id,  # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
             tenant_id=tenant_id,
             kb_id=document.kb_id,
             pipeline_id=pipeline_id,
@@ -168,7 +169,7 @@ class PipelineOperationLogService(CommonService):
         else:
             logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
 
-        logs = logs.where(cls.model.document_id != "x")
+        logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
 
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
@@ -206,7 +207,7 @@
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
         fields = cls.get_dataset_logs_fields()
-        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
+        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
 
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
 
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 324835b11..38fdeec68 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -36,6 +36,7 @@ from api import settings
 from rag.nlp import search
 
 CANVAS_DEBUG_DOC_ID = "dataflow_x"
+GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
 
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
 
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index c66347d3e..35f9d3eca 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -679,7 +679,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
 OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
 
 
-def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+    if isinstance(seconds, str):
+        seconds = float(seconds)
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
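
Why `timeout()` now accepts `str` seconds: timeout values frequently arrive from string sources such as environment variables, and coercing once inside the decorator saves every caller a cast. A standalone sketch of the pattern (`timeout_sketch` and the `LLM_TIMEOUT` variable are illustrative, not the project's actual API):

```python
import asyncio
import os
from functools import wraps

# Minimal sketch (not the full api_utils implementation) of why str is
# accepted: timeouts often come from os.environ, which only yields strings.
def timeout_sketch(seconds: float | int | str | None = None):
    if isinstance(seconds, str):
        seconds = float(seconds)  # "30" -> 30.0; a bad value fails fast here
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
        return wrapper
    return decorator

@timeout_sketch(os.environ.get("LLM_TIMEOUT", "30"))  # env var name is illustrative
async def ask_llm():
    await asyncio.sleep(0.1)
    return "ok"

print(asyncio.run(ask_llm()))  # ok
```
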
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 3495cd4c9..bd26cf8f0 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -50,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -222,7 +222,7 @@ async def collect():
         return None, None
 
     canceled = False
-    if msg.get("doc_id", "") == "x":
+    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
         task = msg
         if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
             print(f"hack {msg['doc_ids']=}=",flush=True)
@@ -537,8 +537,25 @@ async def run_dataflow(task: dict):
             v = vects[i].tolist()
             ck["q_%d_vec" % len(v)] = v
 
+    metadata = {}
+    def dict_update(meta):
+        nonlocal metadata
+        if not meta or not isinstance(meta, dict):
+            return
+        for k,v in meta.items():
+            if k not in metadata:
+                metadata[k] = v
+                continue
+            if isinstance(metadata[k], list):
+                if isinstance(v, list):
+                    metadata[k].extend(v)
+                else:
+                    metadata[k].append(v)
+            else:
+                metadata[k] = v
+
     for ck in chunks:
-        ck["doc_id"] = task["doc_id"]
+        ck["doc_id"] = doc_id
         ck["kb_id"] = [str(task["kb_id"])]
         ck["docnm_kwd"] = task["name"]
         ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
@@ -550,8 +567,19 @@ async def run_dataflow(task: dict):
             del ck["keywords"]
         if "summary" in ck:
             del ck["summary"]
+        if "metadata" in ck:
+            dict_update(ck["metadata"])
+            del ck["metadata"]
         del ck["text"]
 
+    if metadata:
+        e, doc = DocumentService.get_by_id(doc_id)
+        if e:
+            if isinstance(doc.meta_fields, str):
+                doc.meta_fields = json.loads(doc.meta_fields)
+            dict_update(doc.meta_fields)
+            DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
+
     start_ts = timer()
     set_progress(task_id, prog=0.82, msg="Start to index...")
     e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
@@ -562,8 +590,7 @@ async def run_dataflow(task: dict):
 
     time_cost = timer() - start_ts
     task_time_cost = timer() - task_start_ts
     set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
-    logging.info(
-        "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
+    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
 
     PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
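
How the sentinel is consumed in the `collect()` hunk above: a message whose `doc_id` equals `GRAPH_RAPTOR_FAKE_DOC_ID` is a knowledge-base-level graphrag/raptor task carrying the real document IDs in `doc_ids`, while an ordinary message holds a real `doc_id`. A compact, runnable sketch of that routing (the `route` helper is illustrative, not part of the codebase):

```python
# Illustrative routing helper mirroring the collect() check.
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"  # value from api/db/services/task_service.py

def route(msg: dict) -> str:
    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
        # KB-level graphrag/raptor task: real document IDs travel in doc_ids.
        return f"kb-level task over {len(msg.get('doc_ids', []))} document(s)"
    # Ordinary task: doc_id refers to a real document row.
    return f"per-document task for {msg['doc_id']}"

print(route({"doc_id": GRAPH_RAPTOR_FAKE_DOC_ID, "doc_ids": ["d1", "d2"]}))
print(route({"doc_id": "d1"}))
```
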
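The merge semantics of the new `dict_update` helper: keys unseen so far are copied through, list-valued entries accumulate, and scalar entries are overwritten by later chunks. The same logic extracted into a standalone, runnable form:

```python
# dict_update's merge semantics, lifted out of run_dataflow: new keys copy
# through, lists accumulate, scalars overwrite.
metadata = {}

def dict_update(meta):
    if not meta or not isinstance(meta, dict):
        return
    for k, v in meta.items():
        if k not in metadata:
            metadata[k] = v
            continue
        if isinstance(metadata[k], list):
            if isinstance(v, list):
                metadata[k].extend(v)
            else:
                metadata[k].append(v)
        else:
            metadata[k] = v

dict_update({"tags": ["a"], "author": "alice"})
dict_update({"tags": "b", "author": "bob"})
print(metadata)  # {'tags': ['a', 'b'], 'author': 'bob'}
```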