mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add extractor component. (#10271)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -113,6 +113,15 @@ class LLM(ComponentBase):
|
||||
def add2system_prompt(self, txt):
|
||||
self._param.sys_prompt += txt
|
||||
|
||||
def _sys_prompt_and_msg(self, msg, args):
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
p = deepcopy(p)
|
||||
p["content"] = self.string_format(p["content"], args)
|
||||
msg.append(p)
|
||||
return msg, self.string_format(self._param.sys_prompt, args)
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
if self._param.visual_files_var:
|
||||
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
|
||||
@ -128,7 +137,6 @@ class LLM(ComponentBase):
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
sys_prompt = self._param.sys_prompt
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
@ -138,16 +146,8 @@ class LLM(ComponentBase):
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
msg.append(deepcopy(p))
|
||||
|
||||
sys_prompt = self.string_format(sys_prompt, args)
|
||||
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
for m in msg:
|
||||
m["content"] = self.string_format(m["content"], args)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
sys_prompt += citation_prompt(user_defined_prompt)
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
@ -174,6 +175,25 @@ def run():
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
|
||||
@ -121,12 +121,20 @@ class DocumentService(CommonService):
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
|
||||
@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.create_time,
|
||||
|
||||
@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES
|
||||
from api.db.db_models import DB, PipelineOperationLog
|
||||
from api.db.db_models import DB, PipelineOperationLog, Document
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -84,22 +85,20 @@ class PipelineOperationLogService(CommonService):
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
tenant_id = ""
|
||||
title = ""
|
||||
avatar = ""
|
||||
dsl = ""
|
||||
operation_status = ""
|
||||
referred_document_id = document_id
|
||||
|
||||
if referred_document_id == "x" and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
@ -189,6 +188,20 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [
|
||||
Document.id,
|
||||
Document.name,
|
||||
Document.progress
|
||||
]
|
||||
return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
|
||||
cls.model.id == id,
|
||||
Document.progress > 0,
|
||||
Document.progress < 1
|
||||
).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
|
||||
@ -208,3 +221,4 @@ class PipelineOperationLogService(CommonService):
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
|
||||
def trim_header_by_lines(text: str, max_length) -> str:
|
||||
# Trim header text to maximum length while preserving line breaks
|
||||
@ -85,7 +86,7 @@ class TaskService(CommonService):
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == "x" and doc_ids:
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
@ -476,14 +477,14 @@ def has_canceled(task_id):
|
||||
return False
|
||||
|
||||
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
|
||||
|
||||
task = dict(
|
||||
id=task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow",
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import trio
|
||||
@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
|
||||
if len(arr) != 2:
|
||||
return
|
||||
bkt, nm = image_id.split("-")
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
try:
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
for k, v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
#try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
#except Exception as e:
|
||||
# if self.get_exception_default_value():
|
||||
# self.set_exception_default_value()
|
||||
# else:
|
||||
# self.set_output("_ERROR", str(e))
|
||||
# logging.exception(e)
|
||||
# self.callback(-1, str(e))
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self.callback(-1, str(e))
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
|
||||
15
rag/flow/extractor/__init__.py
Normal file
15
rag/flow/extractor/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
59
rag/flow/extractor/extractor.py
Normal file
59
rag/flow/extractor/extractor.py
Normal file
@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
class ExtractorParam(LLMParam):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field_name = ""
|
||||
|
||||
def check(self):
|
||||
super().check()
|
||||
self.check_empty(self.field_name, "Result Destination")
|
||||
|
||||
|
||||
class Extractor(LLM):
|
||||
component_name = "Extractor"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||
inputs = self.get_input_elements()
|
||||
chunks = []
|
||||
chunks_key = ""
|
||||
args = {}
|
||||
for k, v in inputs.items():
|
||||
args[k] = v["value"]
|
||||
if isinstance(args[k], list):
|
||||
chunks = args[k]
|
||||
chunks_key = k
|
||||
|
||||
if chunks:
|
||||
prog = 0
|
||||
for i, ck in enumerate(chunks):
|
||||
args[chunks_key] = ck["text"]
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
ck[self._param.field_name] = self._generate(msg)
|
||||
prog += 1./len(chunks)
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
self.set_output("chunks", chunks)
|
||||
else:
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
|
||||
|
||||
|
||||
38
rag/flow/extractor/schema.py
Normal file
38
rag/flow/extractor/schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExtractorFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: list[str] | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
@ -17,15 +17,11 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
from agent.canvas import Graph
|
||||
from api.db import PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@ -34,9 +30,9 @@ class Pipeline(Graph):
|
||||
if isinstance(dsl, dict):
|
||||
dsl = json.dumps(dsl, ensure_ascii=False)
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
doc_id = None
|
||||
self._doc_id = doc_id
|
||||
if self._doc_id == "x":
|
||||
self._doc_id = None
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
if self._doc_id:
|
||||
@ -80,7 +76,7 @@ class Pipeline(Graph):
|
||||
}
|
||||
]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
|
||||
if self._doc_id:
|
||||
if self._doc_id and self.task_id:
|
||||
percentage = 1.0 / len(self.components.items())
|
||||
msg = ""
|
||||
finished = 0.0
|
||||
@ -96,7 +92,7 @@ class Pipeline(Graph):
|
||||
if finished < 0:
|
||||
break
|
||||
finished += o["trace"][-1]["progress"] * percentage
|
||||
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
|
||||
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -113,34 +109,32 @@ class Pipeline(Graph):
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
|
||||
async def run(self, **kwargs):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60 * 10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.error = ""
|
||||
if not self.path:
|
||||
self.path.append("File")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
)
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
|
||||
if self._doc_id:
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": random.randint(0, 5) / 100.0,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
|
||||
|
||||
idx = len(self.path) - 1
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx - 1])
|
||||
@ -152,23 +146,21 @@ class Pipeline(Graph):
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
self.callback(cpn_obj._id, -1, self.error)
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id,
|
||||
{
|
||||
"progress": 1 if not self.error else -1,
|
||||
"progress_msg": "Pipeline finished...\n" + self.error,
|
||||
"process_duration": time.perf_counter() - st,
|
||||
},
|
||||
)
|
||||
if not self.error:
|
||||
return self.get_component_obj(self.path[-1]).output()
|
||||
|
||||
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": -1,
|
||||
"progress_msg": f"[ERROR]: {self.error}"})
|
||||
|
||||
return {}
|
||||
|
||||
@ -99,7 +99,7 @@ class Splitter(ProcessBase):
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
}
|
||||
for c, img in zip(chunks, images)
|
||||
]
|
||||
|
||||
@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i * 1.0 / len(chunks) / parts)
|
||||
|
||||
|
||||
@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
res = []
|
||||
# add tables
|
||||
|
||||
@ -21,6 +21,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.utils.api_utils import timeout
|
||||
from api.utils.base64_image import image2id
|
||||
@ -49,7 +50,7 @@ from peewee import DoesNotExist
|
||||
from api.db import LLMType, ParserType, PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService, has_canceled
|
||||
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api import settings
|
||||
from api.versions import get_ragflow_version
|
||||
@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
if tracemalloc.is_tracing():
|
||||
@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame):
|
||||
else:
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
|
||||
async def run_dataflow(task: dict):
|
||||
task_start_ts = timer()
|
||||
dataflow_id = task["dataflow_id"]
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
|
||||
pipeline.reset()
|
||||
await pipeline.run(file=task.get("file"))
|
||||
doc_id = task["doc_id"]
|
||||
task_id = task["id"]
|
||||
if task["task_type"] == "dataflow":
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
assert e, "User pipeline not found."
|
||||
dsl = cvs.dsl
|
||||
else:
|
||||
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
|
||||
assert e, "Pipeline log not found."
|
||||
dsl = pipeline_log.dsl
|
||||
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
|
||||
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
return
|
||||
|
||||
if not chunks:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
|
||||
if chunks.get("chunks"):
|
||||
chunks = chunks["chunks"]
|
||||
elif chunks.get("json"):
|
||||
chunks = chunks["json"]
|
||||
elif chunks.get("markdown"):
|
||||
chunks = [{"text": [chunks["markdown"]]}]
|
||||
elif chunks.get("text"):
|
||||
chunks = [{"text": [chunks["text"]]}]
|
||||
elif chunks.get("html"):
|
||||
chunks = [{"text": [chunks["html"]]}]
|
||||
|
||||
keys = [k for o in chunks for k in list(o.keys())]
|
||||
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
|
||||
set_progress(task_id, prog=0.82, msg="Start to embedding...")
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
|
||||
vects = np.array([])
|
||||
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
|
||||
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
vects = np.concatenate((vects, vts), axis=0)
|
||||
embedding_token_consumption += c
|
||||
prog += delta
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
v = vects[i].tolist()
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
|
||||
for ck in chunks:
|
||||
ck["doc_id"] = task["doc_id"]
|
||||
ck["kb_id"] = [str(task["kb_id"])]
|
||||
ck["docnm_kwd"] = task["name"]
|
||||
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||
if "questions" in ck:
|
||||
del ck["questions"]
|
||||
if "keywords" in ck:
|
||||
del ck["keywords"]
|
||||
if "summary" in ck:
|
||||
del ck["summary"]
|
||||
del ck["text"]
|
||||
|
||||
start_ts = timer()
|
||||
set_progress(task_id, prog=0.82, msg="Start to index...")
|
||||
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
||||
if not e:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
task_time_cost = timer() - task_start_ts
|
||||
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
logging.info(
|
||||
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
return res, tk_count
|
||||
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
|
||||
raise
|
||||
|
||||
|
||||
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
|
||||
return
|
||||
return True
|
||||
|
||||
|
||||
@timeout(60*60*2, 1)
|
||||
async def do_handle_task(task):
|
||||
task_type = task.get("task_type", "")
|
||||
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == "x":
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -569,7 +695,7 @@ async def do_handle_task(task):
|
||||
|
||||
init_kb(task, vector_size)
|
||||
|
||||
if task_type == "dataflow":
|
||||
if task_type[:len("dataflow")] == "dataflow":
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -631,41 +757,9 @@ async def do_handle_task(task):
|
||||
|
||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||
start_ts = timer()
|
||||
doc_store_result = ""
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
|
||||
raise
|
||||
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
|
||||
return
|
||||
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
|
||||
if not e:
|
||||
return
|
||||
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
@ -715,7 +809,8 @@ async def handle_task():
|
||||
task_document_ids = []
|
||||
if task_type in ["graphrag"]:
|
||||
task_document_ids = task["doc_ids"]
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
|
||||
redis_msg.ack()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user