Feat: add extractor component. (#10271)

### What problem does this PR solve?

Adds a new `Extractor` agent component that runs the configured LLM prompt over each upstream chunk (or over plain inputs when no chunk list is present) and writes the answer into a configurable result field. Supporting changes: `LLM` gains a reusable `_sys_prompt_and_msg` helper for prompt formatting; a new `/rerun` endpoint restarts a pipeline from a chosen component, saving the submitted DSL to the pipeline operation log so the task executor can replay it, and `queue_dataflow` accepts a `rerun` flag; `Pipeline.run` now returns its final output and reports progress through `TaskService.update_progress`; the task executor embeds and indexes dataflow output via a new reusable `insert_es` helper; the `Tokenizer` prefers a chunk's `summary` over its raw text when building content tokens; and document listings join the owning pipeline's `UserCanvas` title.
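
As a reading aid, the sketch below mirrors the per-chunk loop that `Extractor._invoke` implements: format the prompt against each chunk, ask the model, and store the answer under the configured field name. The `extract_per_chunk` helper and the `fake_llm` callable are illustrative stand-ins for the component's `_sys_prompt_and_msg`/`_generate` machinery, not code from this PR.

```python
# Illustrative sketch only: a stand-in for the Extractor's per-chunk extraction loop.
from copy import deepcopy


def extract_per_chunk(chunks, prompt_template, field_name, llm):
    """Ask the LLM once per chunk and store the answer under chunk[field_name]."""
    out = deepcopy(chunks)
    for i, ck in enumerate(out):
        sys_prompt = prompt_template.format(text=ck["text"])  # prompt formatting stand-in
        ck[field_name] = llm([{"role": "system", "content": sys_prompt}])
        print(f"{i + 1} / {len(out)}")  # progress-callback stand-in
    return out


if __name__ == "__main__":
    fake_llm = lambda messages: "2021"  # stand-in for the bound chat model
    chunks = [{"text": "Alice joined in 2021."}, {"text": "Bob joined in 2019."}]
    print(extract_per_chunk(chunks, "Extract the joining year from: {text}", "joined_year", fake_llm))
```

Likewise, a hedged sketch of how a client might call the new `/rerun` route; the `/v1/canvas` URL prefix, host/port, and the bearer-token header are assumptions rather than something this diff specifies.

```python
# Hypothetical client call for the new /rerun route; URL prefix and auth are assumptions.
import requests

payload = {
    "id": "<pipeline_operation_log_id>",            # log entry whose document should be re-processed
    "component_id": "<component_to_restart_from>",
    "dsl": {},                                      # pipeline DSL; "path" is overwritten server-side
}
resp = requests.post(
    "http://localhost:9380/v1/canvas/rerun",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
    timeout=30,
)
print(resp.status_code, resp.json())  # expect data == true once the rerun task is queued
```
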
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu committed 2025-09-25 11:34:47 +08:00 (committed by GitHub)
parent 840b2b5809, commit 1b19d302c5
16 changed files with 379 additions and 127 deletions

View File

@@ -113,6 +113,15 @@ class LLM(ComponentBase):
     def add2system_prompt(self, txt):
         self._param.sys_prompt += txt

+    def _sys_prompt_and_msg(self, msg, args):
+        for p in self._param.prompts:
+            if msg and msg[-1]["role"] == p["role"]:
+                continue
+            p = deepcopy(p)
+            p["content"] = self.string_format(p["content"], args)
+            msg.append(p)
+        return msg, self.string_format(self._param.sys_prompt, args)
+
     def _prepare_prompt_variables(self):
         if self._param.visual_files_var:
             self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
@@ -128,7 +137,6 @@ class LLM(ComponentBase):
         args = {}
         vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
-        sys_prompt = self._param.sys_prompt
         for k, o in vars.items():
             args[k] = o["value"]
             if not isinstance(args[k], str):
@@ -138,16 +146,8 @@ class LLM(ComponentBase):
                 args[k] = str(args[k])
             self.set_input_value(k, args[k])

-        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
-        for p in self._param.prompts:
-            if msg and msg[-1]["role"] == p["role"]:
-                continue
-            msg.append(deepcopy(p))
-        sys_prompt = self.string_format(sys_prompt, args)
+        msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
         user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
-        for m in msg:
-            m["content"] = self.string_format(m["content"], args)
         if self._param.cite and self._canvas.get_reference()["chunks"]:
             sys_prompt += citation_prompt(user_defined_prompt)

View File

@@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import queue_dataflow
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
@@ -174,6 +175,25 @@ def run():
     return resp


+@manager.route('/rerun', methods=['POST'])  # noqa: F821
+@validate_request("id", "dsl", "component_id")
+@login_required
+def rerun():
+    req = request.json
+    doc = PipelineOperationLogService.get_documents_info(req["id"])
+    if not doc:
+        return get_data_error_result(message="Document not found.")
+    doc = doc[0]
+    if 0 < doc["progress"] < 1:
+        return get_data_error_result(message=f"`{doc['name']}` is processing...")
+    dsl = req["dsl"]
+    dsl["path"] = [req["component_id"]]
+    PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
+    queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
+    return get_json_result(data=True)
+
+
 @manager.route('/cancel/<task_id>', methods=['PUT'])  # noqa: F821
 @login_required
 def cancel(task_id):

View File

@@ -121,12 +121,20 @@ class DocumentService(CommonService):
                       orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .where(
+                    (cls.model.kb_id == kb_id),
+                    (fn.LOWER(cls.model.name).contains(keywords.lower()))
+                )
         else:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .where(cls.model.kb_id == kb_id)
         if run_status:
             docs = docs.where(cls.model.run.in_(run_status))

View File

@@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
+            cls.model.pipeline_id,
             cls.model.parser_config,
             cls.model.pagerank,
             cls.model.create_time,

View File

@@ -14,12 +14,13 @@
 # limitations under the License.
 #
 import json
+import logging
 from datetime import datetime
 from peewee import fn

 from api.db import VALID_PIPELINE_TASK_TYPES
-from api.db.db_models import DB, PipelineOperationLog
+from api.db.db_models import DB, PipelineOperationLog, Document
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
@@ -84,22 +85,20 @@ class PipelineOperationLogService(CommonService):
     def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
         from rag.flow.pipeline import Pipeline

-        tenant_id = ""
-        title = ""
-        avatar = ""
         dsl = ""
-        operation_status = ""
         referred_document_id = document_id
         if referred_document_id == "x" and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
         DocumentService.update_progress_immediately([document.to_dict()])
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
         if document.progress not in [1, -1]:
             return
         operation_status = document.run
@@ -189,6 +188,20 @@ class PipelineOperationLogService(CommonService):
         return list(logs.dicts()), count

+    @classmethod
+    @DB.connection_context()
+    def get_documents_info(cls, id):
+        fields = [
+            Document.id,
+            Document.name,
+            Document.progress
+        ]
+        return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
+            cls.model.id == id,
+            Document.progress > 0,
+            Document.progress < 1
+        ).dicts()
+
     @classmethod
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
@@ -208,3 +221,4 @@ class PipelineOperationLogService(CommonService):
         logs = logs.paginate(page_number, items_per_page)

         return list(logs.dicts()), count
+

View File

@@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search

+CANVAS_DEBUG_DOC_ID = "dataflow_x"

 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
@@ -85,7 +86,7 @@ class TaskService(CommonService):
         Returns None if task is not found or has exceeded retry limit.
         """
         doc_id = cls.model.doc_id
-        if doc_id == "x" and doc_ids:
+        if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
             doc_id = doc_ids[0]
         fields = [
@@ -476,14 +477,14 @@ def has_canceled(task_id):
     return False


-def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
     task = dict(
         id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
-        task_type="dataflow",
+        task_type="dataflow" if not rerun else "dataflow_rerun",
         priority=priority,
     )

View File

@@ -1,4 +1,5 @@
 import base64
+import logging
 from functools import partial
 from io import BytesIO
@@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
 test_image = base64.b64decode(test_image_base64)


-async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
+async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
     import logging
     from io import BytesIO
     import trio
@@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
     if len(arr) != 2:
         return
     bkt, nm = image_id.split("-")
-    blob = storage_get_func(bucket=bkt, filename=nm)
-    if not blob:
-        return
-    return Image.open(BytesIO(blob))
+    try:
+        blob = storage_get_func(bucket=bkt, filename=nm)
+        if not blob:
+            return
+        return Image.open(BytesIO(blob))
+    except Exception as e:
+        logging.exception(e)

View File

@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import os
 import time
 from functools import partial
 from typing import Any

 import trio
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
@@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
         self.set_output("_created_time", time.perf_counter())
         for k, v in kwargs.items():
             self.set_output(k, v)
-        #try:
-        with trio.fail_after(self._param.timeout):
-            await self._invoke(**kwargs)
-        self.callback(1, "Done")
-        #except Exception as e:
-        #    if self.get_exception_default_value():
-        #        self.set_exception_default_value()
-        #    else:
-        #        self.set_output("_ERROR", str(e))
-        #    logging.exception(e)
-        #    self.callback(-1, str(e))
+        try:
+            with trio.fail_after(self._param.timeout):
+                await self._invoke(**kwargs)
+            self.callback(1, "Done")
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+            self.callback(-1, str(e))
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()

View File

@@ -0,0 +1,15 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

View File

@@ -0,0 +1,59 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+
+from agent.component.llm import LLMParam, LLM
+
+
+class ExtractorParam(LLMParam):
+    def __init__(self):
+        super().__init__()
+        self.field_name = ""
+
+    def check(self):
+        super().check()
+        self.check_empty(self.field_name, "Result Destination")
+
+
+class Extractor(LLM):
+    component_name = "Extractor"
+
+    async def _invoke(self, **kwargs):
+        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
+        inputs = self.get_input_elements()
+        chunks = []
+        chunks_key = ""
+        args = {}
+        for k, v in inputs.items():
+            args[k] = v["value"]
+            if isinstance(args[k], list):
+                chunks = args[k]
+                chunks_key = k
+
+        if chunks:
+            prog = 0
+            for i, ck in enumerate(chunks):
+                args[chunks_key] = ck["text"]
+                msg, sys_prompt = self._sys_prompt_and_msg([], args)
+                msg.insert(0, {"role": "system", "content": sys_prompt})
+                ck[self._param.field_name] = self._generate(msg)
+                prog += 1./len(chunks)
+                self.callback(prog, f"{i+1} / {len(chunks)}")
+            self.set_output("chunks", chunks)
+        else:
+            msg, sys_prompt = self._sys_prompt_and_msg([], args)
+            msg.insert(0, {"role": "system", "content": sys_prompt})
+            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@@ -0,0 +1,38 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Literal
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class ExtractorFromUpstream(BaseModel):
+    created_time: float | None = Field(default=None, alias="_created_time")
+    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
+
+    name: str
+    file: dict | None = Field(default=None)
+
+    chunks: list[dict[str, Any]] | None = Field(default=None)
+
+    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
+    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
+    markdown_result: str | None = Field(default=None, alias="markdown")
+    text_result: str | None = Field(default=None, alias="text")
+    html_result: list[str] | None = Field(default=None, alias="html")
+
+    model_config = ConfigDict(populate_by_name=True, extra="forbid")
+
+    # def to_dict(self, *, exclude_none: bool = True) -> dict:
+    #     return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@@ -17,15 +17,11 @@ import datetime
 import json
 import logging
 import random
-import time
 from timeit import default_timer as timer

 import trio

 from agent.canvas import Graph
-from api.db import PipelineTaskType
 from api.db.services.document_service import DocumentService
-from api.db.services.task_service import has_canceled
-from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
 from rag.utils.redis_conn import REDIS_CONN
@@ -34,9 +30,9 @@ class Pipeline(Graph):
         if isinstance(dsl, dict):
             dsl = json.dumps(dsl, ensure_ascii=False)
         super().__init__(dsl, tenant_id, task_id)
+        if doc_id == CANVAS_DEBUG_DOC_ID:
+            doc_id = None
         self._doc_id = doc_id
-        if self._doc_id == "x":
-            self._doc_id = None
         self._flow_id = flow_id
         self._kb_id = None
         if self._doc_id:
@@ -80,7 +76,7 @@ class Pipeline(Graph):
                 }
             ]
             REDIS_CONN.set_obj(log_key, obj, 60 * 30)
-            if self._doc_id:
+            if self._doc_id and self.task_id:
                 percentage = 1.0 / len(self.components.items())
                 msg = ""
                 finished = 0.0
@@ -96,7 +92,7 @@ class Pipeline(Graph):
                         if finished < 0:
                             break
                         finished += o["trace"][-1]["progress"] * percentage
-                DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
+                TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
         except Exception as e:
             logging.exception(e)
@@ -113,34 +109,32 @@ class Pipeline(Graph):
             logging.exception(e)
             return []

-    def reset(self):
-        super().reset()
+    async def run(self, **kwargs):
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
             REDIS_CONN.set_obj(log_key, [], 60 * 10)
         except Exception as e:
             logging.exception(e)
-        self.error = ""
-
-    async def run(self, **kwargs):
-        st = time.perf_counter()
+
         if not self.path:
             self.path.append("File")
-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-            )
-        self.error = ""
-        idx = len(self.path) - 1
-        if idx == 0:
             cpn_obj = self.get_component_obj(self.path[0])
             await cpn_obj.invoke(**kwargs)
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
-        else:
-            idx += 1
-        self.path.extend(cpn_obj.get_downstream())
+                self.callback(cpn_obj.component_name, -1, self.error)
+
+            if self._doc_id:
+                TaskService.update_progress(self.task_id, {
+                    "progress": random.randint(0, 5) / 100.0,
+                    "progress_msg": "Start the pipeline...",
+                    "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
+
+        idx = len(self.path) - 1
+        cpn_obj = self.get_component_obj(self.path[idx])
+        idx += 1
+        self.path.extend(cpn_obj.get_downstream())

         while idx < len(self.path) and not self.error:
             last_cpn = self.get_component_obj(self.path[idx - 1])
@@ -152,23 +146,21 @@ class Pipeline(Graph):
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(invoke)
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
-                self.callback(cpn_obj.component_name, -1, self.error)
+                self.callback(cpn_obj._id, -1, self.error)
                 break
             idx += 1
             self.path.extend(cpn_obj.get_downstream())

-        self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id,
-                {
-                    "progress": 1 if not self.error else -1,
-                    "progress_msg": "Pipeline finished...\n" + self.error,
-                    "process_duration": time.perf_counter() - st,
-                },
-            )
-            PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
+        self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
+        if not self.error:
+            return self.get_component_obj(self.path[-1]).output()
+
+        TaskService.update_progress(self.task_id, {
+            "progress": -1,
+            "progress_msg": f"[ERROR]: {self.error}"})
+        return {}

View File

@@ -99,7 +99,7 @@ class Splitter(ProcessBase):
                 {
                     "text": RAGFlowPdfParser.remove_tag(c),
                     "image": img,
-                    "positions": RAGFlowPdfParser.extract_positions(c),
+                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
                 }
                 for c, img in zip(chunks, images)
             ]

View File

@@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
                 ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
             if ck.get("keywords"):
                 ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
-            ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
-            ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            if ck.get("summary"):
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            else:
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             if i % 100 == 99:
                 self.callback(i * 1.0 / len(chunks) / parts)

View File

@@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
         res.append(d)
     return res

+
 def tokenize_chunks_with_images(chunks, doc, eng, images):
     res = []
     # wrap up as es documents
@@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
         res.append(d)
     return res

+
 def tokenize_table(tbls, doc, eng, batch_size=10):
     res = []
     # add tables

View File

@@ -21,6 +21,7 @@ import sys
 import threading
 import time
 from api.db.services.canvas_service import UserCanvasService
+from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
@@ -49,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
     max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
     logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")

+
 # SIGUSR2 handler: stop tracemalloc
 def stop_tracemalloc(signum, frame):
     if tracemalloc.is_tracing():
@@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame):
     else:
         logging.info("tracemalloc not running")

+
 class TaskCanceledException(Exception):
     def __init__(self, msg):
         self.msg = msg
@@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 async def run_dataflow(task: dict):
+    task_start_ts = timer()
     dataflow_id = task["dataflow_id"]
-    e, cvs = UserCanvasService.get_by_id(dataflow_id)
-    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
-    pipeline.reset()
-    await pipeline.run(file=task.get("file"))
+    doc_id = task["doc_id"]
+    task_id = task["id"]
+    if task["task_type"] == "dataflow":
+        e, cvs = UserCanvasService.get_by_id(dataflow_id)
+        assert e, "User pipeline not found."
+        dsl = cvs.dsl
+    else:
+        e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
+        assert e, "Pipeline log not found."
+        dsl = pipeline_log.dsl
+
+    pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
+    chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
+    if doc_id == CANVAS_DEBUG_DOC_ID:
+        return
+
+    if not chunks:
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
+        return
+
+    embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
+    if chunks.get("chunks"):
+        chunks = chunks["chunks"]
+    elif chunks.get("json"):
+        chunks = chunks["json"]
+    elif chunks.get("markdown"):
+        chunks = [{"text": [chunks["markdown"]]}]
+    elif chunks.get("text"):
+        chunks = [{"text": [chunks["text"]]}]
+    elif chunks.get("html"):
+        chunks = [{"text": [chunks["html"]]}]
+
+    keys = [k for o in chunks for k in list(o.keys())]
+    if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
+        set_progress(task_id, prog=0.82, msg="Start to embedding...")
+        e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+        embedding_id = kb.embd_id
+        embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
+
+        @timeout(60)
+        def batch_encode(txts):
+            nonlocal embedding_model
+            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
+
+        vects = np.array([])
+        texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
+        delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
+        prog = 0.8
+        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
+            async with embed_limiter:
+                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
+            if len(vects) == 0:
+                vects = vts
+            else:
+                vects = np.concatenate((vects, vts), axis=0)
+            embedding_token_consumption += c
+            prog += delta
+            set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
+
+        assert len(vects) == len(chunks)
+        for i, ck in enumerate(chunks):
+            v = vects[i].tolist()
+            ck["q_%d_vec" % len(v)] = v
+
+    for ck in chunks:
+        ck["doc_id"] = task["doc_id"]
+        ck["kb_id"] = [str(task["kb_id"])]
+        ck["docnm_kwd"] = task["name"]
+        ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        ck["create_timestamp_flt"] = datetime.now().timestamp()
+        ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
+        if "questions" in ck:
+            del ck["questions"]
+        if "keywords" in ck:
+            del ck["keywords"]
+        if "summary" in ck:
+            del ck["summary"]
+        del ck["text"]
+
+    start_ts = timer()
+    set_progress(task_id, prog=0.82, msg="Start to index...")
+    e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
+    if not e:
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
+        return
+    time_cost = timer() - start_ts
+    task_time_cost = timer() - task_start_ts
+    set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
+    logging.info(
+        "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
+    PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)


 @timeout(3600)
@@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count


+async def delete_image(kb_id, chunk_id):
+    try:
+        async with minio_limiter:
+            STORAGE_IMPL.delete(kb_id, chunk_id)
+    except Exception:
+        logging.exception(f"Deleting image of chunk {chunk_id} got exception")
+        raise
+
+
+async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
+    for b in range(0, len(chunks), DOC_BULK_SIZE):
+        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        task_canceled = has_canceled(task_id)
+        if task_canceled:
+            progress_callback(-1, msg="Task has been canceled.")
+            return
+        if b % 128 == 0:
+            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
+        if doc_store_result:
+            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
+            progress_callback(-1, msg=error_message)
+            raise Exception(error_message)
+        chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
+        chunk_ids_str = " ".join(chunk_ids)
+        try:
+            TaskService.update_chunk_ids(task_id, chunk_ids_str)
+        except DoesNotExist:
+            logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
+            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+            async with trio.open_nursery() as nursery:
+                for chunk_id in chunk_ids:
+                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
+            progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
+            return
+    return True
+
+
 @timeout(60*60*2, 1)
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
-    if task_type == "dataflow" and task.get("doc_id", "") == "x":
+    if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
         await run_dataflow(task)
         return
@@ -569,7 +695,7 @@ async def do_handle_task(task):
     init_kb(task, vector_size)

-    if task_type == "dataflow":
+    if task_type[:len("dataflow")] == "dataflow":
         await run_dataflow(task)
         return
@@ -631,41 +757,9 @@ async def do_handle_task(task):
     chunk_count = len(set([chunk["id"] for chunk in chunks]))
     start_ts = timer()
-    doc_store_result = ""
-
-    async def delete_image(kb_id, chunk_id):
-        try:
-            async with minio_limiter:
-                STORAGE_IMPL.delete(kb_id, chunk_id)
-        except Exception:
-            logging.exception(
-                "Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
-            raise
-
-    for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
-        task_canceled = has_canceled(task_id)
-        if task_canceled:
-            progress_callback(-1, msg="Task has been canceled.")
-            return
-        if b % 128 == 0:
-            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
-        if doc_store_result:
-            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
-            progress_callback(-1, msg=error_message)
-            raise Exception(error_message)
-        chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
-        chunk_ids_str = " ".join(chunk_ids)
-        try:
-            TaskService.update_chunk_ids(task["id"], chunk_ids_str)
-        except DoesNotExist:
-            logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
-            async with trio.open_nursery() as nursery:
-                for chunk_id in chunk_ids:
-                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
-            progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
-            return
+    e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
+    if not e:
+        return

     logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
                                                                                      task_to_page, len(chunks),
@@ -715,7 +809,8 @@ async def handle_task():
             task_document_ids = []
             if task_type in ["graphrag"]:
                 task_document_ids = task["doc_ids"]
-            PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+            if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
+                PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
         redis_msg.ack()