diff --git a/agent/component/llm.py b/agent/component/llm.py index b13c4a87e..9db894305 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -113,6 +113,15 @@ class LLM(ComponentBase): def add2system_prompt(self, txt): self._param.sys_prompt += txt + def _sys_prompt_and_msg(self, msg, args): + for p in self._param.prompts: + if msg and msg[-1]["role"] == p["role"]: + continue + p = deepcopy(p) + p["content"] = self.string_format(p["content"], args) + msg.append(p) + return msg, self.string_format(self._param.sys_prompt, args) + def _prepare_prompt_variables(self): if self._param.visual_files_var: self.imgs = self._canvas.get_variable_value(self._param.visual_files_var) @@ -128,7 +137,6 @@ class LLM(ComponentBase): args = {} vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs - sys_prompt = self._param.sys_prompt for k, o in vars.items(): args[k] = o["value"] if not isinstance(args[k], str): @@ -138,16 +146,8 @@ class LLM(ComponentBase): args[k] = str(args[k]) self.set_input_value(k, args[k]) - msg = self._canvas.get_history(self._param.message_history_window_size)[:-1] - for p in self._param.prompts: - if msg and msg[-1]["role"] == p["role"]: - continue - msg.append(deepcopy(p)) - - sys_prompt = self.string_format(sys_prompt, args) + msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args) user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt) - for m in msg: - m["content"] = self.string_format(m["content"], args) if self._param.cite and self._canvas.get_reference()["chunks"]: sys_prompt += citation_prompt(user_defined_prompt) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 3d7b13d8d..a205affa2 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService from api.db.services.document_service import DocumentService from api.db.services.file_service import FileService +from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.db.services.task_service import queue_dataflow from api.db.services.user_service import TenantService from api.db.services.user_canvas_version import UserCanvasVersionService @@ -174,6 +175,25 @@ def run(): return resp +@manager.route('/rerun', methods=['POST']) # noqa: F821 +@validate_request("id", "dsl", "component_id") +@login_required +def rerun(): + req = request.json + doc = PipelineOperationLogService.get_documents_info(req["id"]) + if not doc: + return get_data_error_result(message="Document not found.") + doc = doc[0] + if 0 < doc["progress"] < 1: + return get_data_error_result(message=f"`{doc['name']}` is processing...") + + dsl = req["dsl"] + dsl["path"] = [req["component_id"]] + PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl}) + queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True) + return get_json_result(data=True) + + @manager.route('/cancel/', methods=['PUT']) # noqa: F821 @login_required def cancel(task_id): diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 56d08b230..a8aadd6ac 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -121,12 +121,20 @@ class DocumentService(CommonService): orderby, desc, keywords, run_status, types, suffix): fields = cls.get_cls_model_fields() if keywords: - docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) - ) + docs = cls.model.select(*[*fields, UserCanvas.title])\ + .join(File2Document, on=(File2Document.document_id == cls.model.id))\ + .join(File, on=(File.id == File2Document.file_id))\ + .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\ + .where( + (cls.model.kb_id == kb_id), + (fn.LOWER(cls.model.name).contains(keywords.lower())) + ) else: - docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) + docs = cls.model.select(*[*fields, UserCanvas.title])\ + .join(File2Document, on=(File2Document.document_id == cls.model.id))\ + .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\ + .join(File, on=(File.id == File2Document.file_id))\ + .where(cls.model.kb_id == kb_id) if run_status: docs = docs.where(cls.model.run.in_(run_status)) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 454bdbdc7..bb4bbdde3 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService): cls.model.token_num, cls.model.chunk_num, cls.model.parser_id, + cls.model.pipeline_id, cls.model.parser_config, cls.model.pagerank, cls.model.create_time, diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index d316cb46b..1d71e41cd 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -14,12 +14,13 @@ # limitations under the License. # import json +import logging from datetime import datetime from peewee import fn from api.db import VALID_PIPELINE_TASK_TYPES -from api.db.db_models import DB, PipelineOperationLog +from api.db.db_models import DB, PipelineOperationLog, Document from api.db.services.canvas_service import UserCanvasService from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService @@ -84,22 +85,20 @@ class PipelineOperationLogService(CommonService): def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]): from rag.flow.pipeline import Pipeline - tenant_id = "" - title = "" - avatar = "" dsl = "" - operation_status = "" referred_document_id = document_id if referred_document_id == "x" and fake_document_ids: referred_document_id = fake_document_ids[0] ok, document = DocumentService.get_by_id(referred_document_id) if not ok: - raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found") + logging.warning(f"Document for referred_document_id {referred_document_id} not found") + return DocumentService.update_progress_immediately([document.to_dict()]) ok, document = DocumentService.get_by_id(referred_document_id) if not ok: - raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found") + logging.warning(f"Document for referred_document_id {referred_document_id} not found") + return if document.progress not in [1, -1]: return operation_status = document.run @@ -189,6 +188,20 @@ class PipelineOperationLogService(CommonService): return list(logs.dicts()), count + @classmethod + @DB.connection_context() + def get_documents_info(cls, id): + fields = [ + Document.id, + Document.name, + Document.progress + ] + return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where( + cls.model.id == id, + Document.progress > 0, + Document.progress < 1 + ).dicts() + @classmethod @DB.connection_context() def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status): @@ -208,3 +221,4 @@ class PipelineOperationLogService(CommonService): logs = logs.paginate(page_number, items_per_page) return list(logs.dicts()), count + diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 215e5c724..324835b11 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN from api import settings from rag.nlp import search +CANVAS_DEBUG_DOC_ID = "dataflow_x" def trim_header_by_lines(text: str, max_length) -> str: # Trim header text to maximum length while preserving line breaks @@ -85,7 +86,7 @@ class TaskService(CommonService): Returns None if task is not found or has exceeded retry limit. """ doc_id = cls.model.doc_id - if doc_id == "x" and doc_ids: + if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids: doc_id = doc_ids[0] fields = [ @@ -476,14 +477,14 @@ def has_canceled(task_id): return False -def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]: +def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]: task = dict( id=task_id, doc_id=doc_id, from_page=0, to_page=100000000, - task_type="dataflow", + task_type="dataflow" if not rerun else "dataflow_rerun", priority=priority, ) diff --git a/api/utils/base64_image.py b/api/utils/base64_image.py index aa24ac63a..25afcf332 100644 --- a/api/utils/base64_image.py +++ b/api/utils/base64_image.py @@ -1,4 +1,5 @@ import base64 +import logging from functools import partial from io import BytesIO @@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3 test_image = base64.b64decode(test_image_base64) -async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"): +async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"): import logging from io import BytesIO import trio @@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial): if len(arr) != 2: return bkt, nm = image_id.split("-") - blob = storage_get_func(bucket=bkt, filename=nm) - if not blob: - return - return Image.open(BytesIO(blob)) \ No newline at end of file + try: + blob = storage_get_func(bucket=bkt, filename=nm) + if not blob: + return + return Image.open(BytesIO(blob)) + except Exception as e: + logging.exception(e) diff --git a/rag/flow/base.py b/rag/flow/base.py index 0809062d7..fae5f1ed1 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import os import time from functools import partial from typing import Any - import trio - from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout @@ -43,17 +42,17 @@ class ProcessBase(ComponentBase): self.set_output("_created_time", time.perf_counter()) for k, v in kwargs.items(): self.set_output(k, v) - #try: - with trio.fail_after(self._param.timeout): - await self._invoke(**kwargs) - self.callback(1, "Done") - #except Exception as e: - # if self.get_exception_default_value(): - # self.set_exception_default_value() - # else: - # self.set_output("_ERROR", str(e)) - # logging.exception(e) - # self.callback(-1, str(e)) + try: + with trio.fail_after(self._param.timeout): + await self._invoke(**kwargs) + self.callback(1, "Done") + except Exception as e: + if self.get_exception_default_value(): + self.set_exception_default_value() + else: + self.set_output("_ERROR", str(e)) + logging.exception(e) + self.callback(-1, str(e)) self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return self.output() diff --git a/rag/flow/extractor/__init__.py b/rag/flow/extractor/__init__.py new file mode 100644 index 000000000..b4663378e --- /dev/null +++ b/rag/flow/extractor/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py new file mode 100644 index 000000000..242da56a8 --- /dev/null +++ b/rag/flow/extractor/extractor.py @@ -0,0 +1,59 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from agent.component.llm import LLMParam, LLM + + +class ExtractorParam(LLMParam): + def __init__(self): + super().__init__() + self.field_name = "" + + def check(self): + super().check() + self.check_empty(self.field_name, "Result Destination") + + +class Extractor(LLM): + component_name = "Extractor" + + async def _invoke(self, **kwargs): + self.callback(random.randint(1, 5) / 100.0, "Start to generate.") + inputs = self.get_input_elements() + chunks = [] + chunks_key = "" + args = {} + for k, v in inputs.items(): + args[k] = v["value"] + if isinstance(args[k], list): + chunks = args[k] + chunks_key = k + + if chunks: + prog = 0 + for i, ck in enumerate(chunks): + args[chunks_key] = ck["text"] + msg, sys_prompt = self._sys_prompt_and_msg([], args) + msg.insert(0, {"role": "system", "content": sys_prompt}) + ck[self._param.field_name] = self._generate(msg) + prog += 1./len(chunks) + self.callback(prog, f"{i+1} / {len(chunks)}") + self.set_output("chunks", chunks) + else: + msg, sys_prompt = self._sys_prompt_and_msg([], args) + msg.insert(0, {"role": "system", "content": sys_prompt}) + self.set_output("chunks", [{self._param.field_name: self._generate(msg)}]) + + diff --git a/rag/flow/extractor/schema.py b/rag/flow/extractor/schema.py new file mode 100644 index 000000000..214e500e2 --- /dev/null +++ b/rag/flow/extractor/schema.py @@ -0,0 +1,38 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class ExtractorFromUpstream(BaseModel): + created_time: float | None = Field(default=None, alias="_created_time") + elapsed_time: float | None = Field(default=None, alias="_elapsed_time") + + name: str + file: dict | None = Field(default=None) + chunks: list[dict[str, Any]] | None = Field(default=None) + + output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) + + json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") + markdown_result: str | None = Field(default=None, alias="markdown") + text_result: str | None = Field(default=None, alias="text") + html_result: list[str] | None = Field(default=None, alias="html") + + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + # def to_dict(self, *, exclude_none: bool = True) -> dict: + # return self.model_dump(by_alias=True, exclude_none=exclude_none) diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index 4f2211df0..8a30dce64 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -17,15 +17,11 @@ import datetime import json import logging import random -import time from timeit import default_timer as timer import trio - from agent.canvas import Graph -from api.db import PipelineTaskType from api.db.services.document_service import DocumentService -from api.db.services.task_service import has_canceled -from api.db.services.pipeline_operation_log_service import PipelineOperationLogService +from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID from rag.utils.redis_conn import REDIS_CONN @@ -34,9 +30,9 @@ class Pipeline(Graph): if isinstance(dsl, dict): dsl = json.dumps(dsl, ensure_ascii=False) super().__init__(dsl, tenant_id, task_id) + if doc_id == CANVAS_DEBUG_DOC_ID: + doc_id = None self._doc_id = doc_id - if self._doc_id == "x": - self._doc_id = None self._flow_id = flow_id self._kb_id = None if self._doc_id: @@ -80,7 +76,7 @@ class Pipeline(Graph): } ] REDIS_CONN.set_obj(log_key, obj, 60 * 30) - if self._doc_id: + if self._doc_id and self.task_id: percentage = 1.0 / len(self.components.items()) msg = "" finished = 0.0 @@ -96,7 +92,7 @@ class Pipeline(Graph): if finished < 0: break finished += o["trace"][-1]["progress"] * percentage - DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg}) + TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg}) except Exception as e: logging.exception(e) @@ -113,34 +109,32 @@ class Pipeline(Graph): logging.exception(e) return [] - def reset(self): - super().reset() + + async def run(self, **kwargs): log_key = f"{self._flow_id}-{self.task_id}-logs" try: REDIS_CONN.set_obj(log_key, [], 60 * 10) except Exception as e: logging.exception(e) - - async def run(self, **kwargs): - st = time.perf_counter() + self.error = "" if not self.path: self.path.append("File") - - if self._doc_id: - DocumentService.update_by_id( - self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} - ) - - self.error = "" - idx = len(self.path) - 1 - if idx == 0: cpn_obj = self.get_component_obj(self.path[0]) await cpn_obj.invoke(**kwargs) if cpn_obj.error(): self.error = "[ERROR]" + cpn_obj.error() - else: - idx += 1 - self.path.extend(cpn_obj.get_downstream()) + self.callback(cpn_obj.component_name, -1, self.error) + + if self._doc_id: + TaskService.update_progress(self.task_id, { + "progress": random.randint(0, 5) / 100.0, + "progress_msg": "Start the pipeline...", + "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}) + + idx = len(self.path) - 1 + cpn_obj = self.get_component_obj(self.path[idx]) + idx += 1 + self.path.extend(cpn_obj.get_downstream()) while idx < len(self.path) and not self.error: last_cpn = self.get_component_obj(self.path[idx - 1]) @@ -152,23 +146,21 @@ class Pipeline(Graph): async with trio.open_nursery() as nursery: nursery.start_soon(invoke) + if cpn_obj.error(): self.error = "[ERROR]" + cpn_obj.error() - self.callback(cpn_obj.component_name, -1, self.error) + self.callback(cpn_obj._id, -1, self.error) break idx += 1 self.path.extend(cpn_obj.get_downstream()) - self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False)) + self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False)) - if self._doc_id: - DocumentService.update_by_id( - self._doc_id, - { - "progress": 1 if not self.error else -1, - "progress_msg": "Pipeline finished...\n" + self.error, - "process_duration": time.perf_counter() - st, - }, - ) + if not self.error: + return self.get_component_obj(self.path[-1]).output() - PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE) + TaskService.update_progress(self.task_id, { + "progress": -1, + "progress_msg": f"[ERROR]: {self.error}"}) + + return {} diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index 10976bb53..43584bbfc 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -99,7 +99,7 @@ class Splitter(ProcessBase): { "text": RAGFlowPdfParser.remove_tag(c), "image": img, - "positions": RAGFlowPdfParser.extract_positions(c), + "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)], } for c, img in zip(chunks, images) ] diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index bdc4b9adc..a97e43ed2 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -120,8 +120,12 @@ class Tokenizer(ProcessBase): ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"])) if ck.get("keywords"): ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"])) - ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"]) - ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) + if ck.get("summary"): + ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"]) + ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) + else: + ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"]) + ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) if i % 100 == 99: self.callback(i * 1.0 / len(chunks) / parts) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 4acf5da44..37e59205d 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None): res.append(d) return res + def tokenize_chunks_with_images(chunks, doc, eng, images): res = [] # wrap up as es documents @@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images): res.append(d) return res + def tokenize_table(tbls, doc, eng, batch_size=10): res = [] # add tables diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index d60c7e155..3495cd4c9 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -21,6 +21,7 @@ import sys import threading import time from api.db.services.canvas_service import UserCanvasService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.utils.api_utils import timeout from api.utils.base64_image import image2id @@ -49,7 +50,7 @@ from peewee import DoesNotExist from api.db import LLMType, ParserType, PipelineTaskType from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle -from api.db.services.task_service import TaskService, has_canceled +from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID from api.db.services.file2document_service import File2DocumentService from api import settings from api.versions import get_ragflow_version @@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame): max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") + # SIGUSR2 handler: stop tracemalloc def stop_tracemalloc(signum, frame): if tracemalloc.is_tracing(): @@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame): else: logging.info("tracemalloc not running") + class TaskCanceledException(Exception): def __init__(self, msg): self.msg = msg @@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None): async def run_dataflow(task: dict): + task_start_ts = timer() dataflow_id = task["dataflow_id"] - e, cvs = UserCanvasService.get_by_id(dataflow_id) - pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id) - pipeline.reset() - await pipeline.run(file=task.get("file")) + doc_id = task["doc_id"] + task_id = task["id"] + if task["task_type"] == "dataflow": + e, cvs = UserCanvasService.get_by_id(dataflow_id) + assert e, "User pipeline not found." + dsl = cvs.dsl + else: + e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id) + assert e, "Pipeline log not found." + dsl = pipeline_log.dsl + pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id) + chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run() + if doc_id == CANVAS_DEBUG_DOC_ID: + return + + if not chunks: + PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) + return + + embedding_token_consumption = chunks.get("embedding_token_consumption", 0) + if chunks.get("chunks"): + chunks = chunks["chunks"] + elif chunks.get("json"): + chunks = chunks["json"] + elif chunks.get("markdown"): + chunks = [{"text": [chunks["markdown"]]}] + elif chunks.get("text"): + chunks = [{"text": [chunks["text"]]}] + elif chunks.get("html"): + chunks = [{"text": [chunks["html"]]}] + + keys = [k for o in chunks for k in list(o.keys())] + if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]): + set_progress(task_id, prog=0.82, msg="Start to embedding...") + e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) + embedding_id = kb.embd_id + embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id) + @timeout(60) + def batch_encode(txts): + nonlocal embedding_model + return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts]) + vects = np.array([]) + texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks] + delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE) + prog = 0.8 + for i in range(0, len(texts), EMBEDDING_BATCH_SIZE): + async with embed_limiter: + vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE])) + if len(vects) == 0: + vects = vts + else: + vects = np.concatenate((vects, vts), axis=0) + embedding_token_consumption += c + prog += delta + set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}") + + assert len(vects) == len(chunks) + for i, ck in enumerate(chunks): + v = vects[i].tolist() + ck["q_%d_vec" % len(v)] = v + + for ck in chunks: + ck["doc_id"] = task["doc_id"] + ck["kb_id"] = [str(task["kb_id"])] + ck["docnm_kwd"] = task["name"] + ck["create_time"] = str(datetime.now()).replace("T", " ")[:19] + ck["create_timestamp_flt"] = datetime.now().timestamp() + ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() + if "questions" in ck: + del ck["questions"] + if "keywords" in ck: + del ck["keywords"] + if "summary" in ck: + del ck["summary"] + del ck["text"] + + start_ts = timer() + set_progress(task_id, prog=0.82, msg="Start to index...") + e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) + if not e: + PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) + return + + time_cost = timer() - start_ts + task_time_cost = timer() - task_start_ts + set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) + logging.info( + "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost)) + PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) @timeout(3600) @@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): return res, tk_count +async def delete_image(kb_id, chunk_id): + try: + async with minio_limiter: + STORAGE_IMPL.delete(kb_id, chunk_id) + except Exception: + logging.exception(f"Deleting image of chunk {chunk_id} got exception") + raise + + +async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback): + for b in range(0, len(chunks), DOC_BULK_SIZE): + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + task_canceled = has_canceled(task_id) + if task_canceled: + progress_callback(-1, msg="Task has been canceled.") + return + if b % 128 == 0: + progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") + if doc_store_result: + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" + progress_callback(-1, msg=error_message) + raise Exception(error_message) + chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]] + chunk_ids_str = " ".join(chunk_ids) + try: + TaskService.update_chunk_ids(task_id, chunk_ids_str) + except DoesNotExist: + logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) + async with trio.open_nursery() as nursery: + for chunk_id in chunk_ids: + nursery.start_soon(delete_image, task_dataset_id, chunk_id) + progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") + return + return True + + @timeout(60*60*2, 1) async def do_handle_task(task): task_type = task.get("task_type", "") - if task_type == "dataflow" and task.get("doc_id", "") == "x": + if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID: await run_dataflow(task) return @@ -569,7 +695,7 @@ async def do_handle_task(task): init_kb(task, vector_size) - if task_type == "dataflow": + if task_type[:len("dataflow")] == "dataflow": await run_dataflow(task) return @@ -631,41 +757,9 @@ async def do_handle_task(task): chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() - doc_store_result = "" - - async def delete_image(kb_id, chunk_id): - try: - async with minio_limiter: - STORAGE_IMPL.delete(kb_id, chunk_id) - except Exception: - logging.exception( - "Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id)) - raise - - for b in range(0, len(chunks), DOC_BULK_SIZE): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) - task_canceled = has_canceled(task_id) - if task_canceled: - progress_callback(-1, msg="Task has been canceled.") - return - if b % 128 == 0: - progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") - if doc_store_result: - error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" - progress_callback(-1, msg=error_message) - raise Exception(error_message) - chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]] - chunk_ids_str = " ".join(chunk_ids) - try: - TaskService.update_chunk_ids(task["id"], chunk_ids_str) - except DoesNotExist: - logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) - async with trio.open_nursery() as nursery: - for chunk_id in chunk_ids: - nursery.start_soon(delete_image, task_dataset_id, chunk_id) - progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.") - return + e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback) + if not e: + return logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), @@ -715,7 +809,8 @@ async def handle_task(): task_document_ids = [] if task_type in ["graphrag"]: task_document_ids = task["doc_ids"] - PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids) + if task["doc_id"] != CANVAS_DEBUG_DOC_ID: + PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids) redis_msg.ack()