Feat: add extractor component. (#10271)

### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -113,6 +113,15 @@ class LLM(ComponentBase):
     def add2system_prompt(self, txt):
         self._param.sys_prompt += txt
 
+    def _sys_prompt_and_msg(self, msg, args):
+        for p in self._param.prompts:
+            if msg and msg[-1]["role"] == p["role"]:
+                continue
+            p = deepcopy(p)
+            p["content"] = self.string_format(p["content"], args)
+            msg.append(p)
+        return msg, self.string_format(self._param.sys_prompt, args)
+
     def _prepare_prompt_variables(self):
         if self._param.visual_files_var:
             self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
@@ -128,7 +137,6 @@ class LLM(ComponentBase):
 
         args = {}
         vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
-        sys_prompt = self._param.sys_prompt
         for k, o in vars.items():
            args[k] = o["value"]
            if not isinstance(args[k], str):
@@ -138,16 +146,8 @@ class LLM(ComponentBase):
                 args[k] = str(args[k])
             self.set_input_value(k, args[k])
 
-        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
-        for p in self._param.prompts:
-            if msg and msg[-1]["role"] == p["role"]:
-                continue
-            msg.append(deepcopy(p))
-
-        sys_prompt = self.string_format(sys_prompt, args)
+        msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
         user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
-        for m in msg:
-            m["content"] = self.string_format(m["content"], args)
         if self._param.cite and self._canvas.get_reference()["chunks"]:
             sys_prompt += citation_prompt(user_defined_prompt)
 
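For orientation, here is a minimal standalone sketch of the prompt-assembly behaviour that the new `_sys_prompt_and_msg` helper centralises. `string_format` below is a naive stand-in for the component's own formatter, and the prompt/args values are made up:

```python
from copy import deepcopy

def string_format(template: str, args: dict) -> str:
    # Stand-in for LLM.string_format: naive placeholder substitution.
    for k, v in args.items():
        template = template.replace("{" + k + "}", str(v))
    return template

def sys_prompt_and_msg(prompts, sys_prompt, msg, args):
    # Mirrors the new helper: format each configured prompt with the runtime
    # args and append it, skipping a prompt whose role would duplicate the
    # role of the last message already in the history.
    for p in prompts:
        if msg and msg[-1]["role"] == p["role"]:
            continue
        p = deepcopy(p)
        p["content"] = string_format(p["content"], args)
        msg.append(p)
    return msg, string_format(sys_prompt, args)

if __name__ == "__main__":
    prompts = [{"role": "user", "content": "Summarize: {chunk}"}]
    msg, sys_p = sys_prompt_and_msg(prompts, "You are a helpful extractor.", [], {"chunk": "RAGFlow is ..."})
    print(msg)    # [{'role': 'user', 'content': 'Summarize: RAGFlow is ...'}]
    print(sys_p)  # You are a helpful extractor.
```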
@@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import queue_dataflow
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
@@ -174,6 +175,25 @@ def run():
     return resp
 
 
+@manager.route('/rerun', methods=['POST'])  # noqa: F821
+@validate_request("id", "dsl", "component_id")
+@login_required
+def rerun():
+    req = request.json
+    doc = PipelineOperationLogService.get_documents_info(req["id"])
+    if not doc:
+        return get_data_error_result(message="Document not found.")
+    doc = doc[0]
+    if 0 < doc["progress"] < 1:
+        return get_data_error_result(message=f"`{doc['name']}` is processing...")
+
+    dsl = req["dsl"]
+    dsl["path"] = [req["component_id"]]
+    PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
+    queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
+    return get_json_result(data=True)
+
+
 @manager.route('/cancel/<task_id>', methods=['PUT'])  # noqa: F821
 @login_required
 def cancel(task_id):
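For reference, a sketch of how a client might call the new endpoint. The URL prefix, port and authentication below are assumptions, since the blueprint mount point and the login flow are not part of this diff; the payload keys match `@validate_request("id", "dsl", "component_id")`:

```python
import requests

session = requests.Session()
# Assumption: the canvas blueprint is mounted under /v1/canvas and the session
# is already authenticated (the endpoint is @login_required).
BASE = "http://localhost:9380/v1/canvas"

payload = {
    "id": "<pipeline-operation-log-id>",     # log entry whose document should be re-parsed
    "dsl": {"components": {}, "path": []},   # pipeline DSL to persist back onto the log (placeholder)
    "component_id": "splitter_0",            # hypothetical component id the rerun should start from
}
resp = session.post(f"{BASE}/rerun", json=payload)
print(resp.json())  # on success the handler returns get_json_result(data=True)
```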
@@ -121,12 +121,20 @@ class DocumentService(CommonService):
                           orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .where(
+                    (cls.model.kb_id == kb_id),
+                    (fn.LOWER(cls.model.name).contains(keywords.lower()))
+                )
         else:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .where(cls.model.kb_id == kb_id)
 
         if run_status:
             docs = docs.where(cls.model.run.in_(run_status))
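Both branches now LEFT-OUTER-join `UserCanvas` on `pipeline_id`, so documents without an associated pipeline still come back, with a NULL title. A self-contained peewee sketch of that join semantics (toy models, not the real `Document`/`File2Document` schema):

```python
from peewee import SqliteDatabase, Model, CharField, ForeignKeyField, JOIN

db = SqliteDatabase(":memory:")

class Base(Model):
    class Meta:
        database = db

class UserCanvas(Base):
    title = CharField()

class Document(Base):
    name = CharField()
    pipeline = ForeignKeyField(UserCanvas, null=True)  # nullable, like pipeline_id

db.create_tables([UserCanvas, Document])
pipe = UserCanvas.create(title="General pipeline")
Document.create(name="with_pipeline.pdf", pipeline=pipe)
Document.create(name="no_pipeline.pdf", pipeline=None)

# LEFT OUTER join: the document without a pipeline is not filtered out.
query = (Document
         .select(Document.name, UserCanvas.title)
         .join(UserCanvas, join_type=JOIN.LEFT_OUTER)
         .dicts())
for row in query:
    print(row)
# {'name': 'with_pipeline.pdf', 'title': 'General pipeline'}
# {'name': 'no_pipeline.pdf', 'title': None}
```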
@@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
+            cls.model.pipeline_id,
             cls.model.parser_config,
             cls.model.pagerank,
             cls.model.create_time,
@@ -14,12 +14,13 @@
 # limitations under the License.
 #
 import json
+import logging
 from datetime import datetime
 
 from peewee import fn
 
 from api.db import VALID_PIPELINE_TASK_TYPES
-from api.db.db_models import DB, PipelineOperationLog
+from api.db.db_models import DB, PipelineOperationLog, Document
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
@@ -84,22 +85,20 @@ class PipelineOperationLogService(CommonService):
     def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
         from rag.flow.pipeline import Pipeline
 
-        tenant_id = ""
-        title = ""
-        avatar = ""
         dsl = ""
-        operation_status = ""
         referred_document_id = document_id
 
         if referred_document_id == "x" and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
         DocumentService.update_progress_immediately([document.to_dict()])
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
         if document.progress not in [1, -1]:
             return
         operation_status = document.run
@@ -189,6 +188,20 @@ class PipelineOperationLogService(CommonService):
 
         return list(logs.dicts()), count
 
+    @classmethod
+    @DB.connection_context()
+    def get_documents_info(cls, id):
+        fields = [
+            Document.id,
+            Document.name,
+            Document.progress
+        ]
+        return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
+            cls.model.id == id,
+            Document.progress > 0,
+            Document.progress < 1
+        ).dicts()
+
     @classmethod
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
@@ -208,3 +221,4 @@ class PipelineOperationLogService(CommonService):
         logs = logs.paginate(page_number, items_per_page)
 
         return list(logs.dicts()), count
+
@@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search
 
+CANVAS_DEBUG_DOC_ID = "dataflow_x"
 
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
@@ -85,7 +86,7 @@ class TaskService(CommonService):
             Returns None if task is not found or has exceeded retry limit.
         """
         doc_id = cls.model.doc_id
-        if doc_id == "x" and doc_ids:
+        if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
             doc_id = doc_ids[0]
 
         fields = [
@@ -476,14 +477,14 @@ def has_canceled(task_id):
     return False
 
 
-def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
 
     task = dict(
         id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
-        task_type="dataflow",
+        task_type="dataflow" if not rerun else "dataflow_rerun",
         priority=priority,
     )
 
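A small sketch of what the `rerun` flag changes in the queued payload, and why the worker-side prefix check shown later (`task_type[:len("dataflow")] == "dataflow"`) accepts both variants; the payload is trimmed to the relevant fields and the ids are made up:

```python
def build_task(task_id: str, doc_id: str, rerun: bool = False) -> dict:
    # Mirrors the fields queue_dataflow() now sets, trimmed for illustration.
    return {
        "id": task_id,
        "doc_id": doc_id,
        "task_type": "dataflow" if not rerun else "dataflow_rerun",
        "priority": 0,
    }

normal = build_task("t1", "doc1")
rerun = build_task("t2", "doc1", rerun=True)

# The worker dispatch accepts both, because it only compares the prefix.
for t in (normal, rerun):
    assert t["task_type"][:len("dataflow")] == "dataflow"
    print(t["task_type"], "-> handled by run_dataflow")
```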
@@ -1,4 +1,5 @@
 import base64
+import logging
 from functools import partial
 from io import BytesIO
 
@@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
 test_image = base64.b64decode(test_image_base64)
 
 
-async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
+async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
     import logging
     from io import BytesIO
     import trio
@@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
     if len(arr) != 2:
         return
     bkt, nm = image_id.split("-")
-    blob = storage_get_func(bucket=bkt, filename=nm)
-    if not blob:
-        return
-    return Image.open(BytesIO(blob))
+    try:
+        blob = storage_get_func(bucket=bkt, filename=nm)
+        if not blob:
+            return
+        return Image.open(BytesIO(blob))
+    except Exception as e:
+        logging.exception(e)
@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import os
 import time
 from functools import partial
 from typing import Any
 
 import trio
 
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
 
@@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
         self.set_output("_created_time", time.perf_counter())
         for k, v in kwargs.items():
             self.set_output(k, v)
-        #try:
-        with trio.fail_after(self._param.timeout):
-            await self._invoke(**kwargs)
-        self.callback(1, "Done")
-        #except Exception as e:
-        #    if self.get_exception_default_value():
-        #        self.set_exception_default_value()
-        #    else:
-        #        self.set_output("_ERROR", str(e))
-        #    logging.exception(e)
-        #    self.callback(-1, str(e))
+        try:
+            with trio.fail_after(self._param.timeout):
+                await self._invoke(**kwargs)
+            self.callback(1, "Done")
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+            self.callback(-1, str(e))
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()
 
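The previously commented-out error path is now active: the per-component timeout and any exception are converted into an `_ERROR` output (or the configured default value) instead of aborting the run. A self-contained sketch of the same guard pattern, not the actual `ProcessBase`:

```python
import logging
import trio

async def invoke_with_guard(fn, timeout_s: float, outputs: dict):
    # Same shape as ProcessBase.invoke after this change: enforce a timeout,
    # and turn any failure into an "_ERROR" output instead of crashing.
    try:
        with trio.fail_after(timeout_s):
            await fn()
        outputs["_result"] = "Done"
    except Exception as e:          # trio.TooSlowError is raised on timeout
        outputs["_ERROR"] = str(e)
        logging.exception(e)

async def main():
    outputs = {}

    async def slow():
        await trio.sleep(2)

    await invoke_with_guard(slow, 0.1, outputs)
    print(outputs)  # contains "_ERROR" because the call exceeded the timeout

trio.run(main)
```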
rag/flow/extractor/__init__.py (new file, 15 lines): contains only the standard Apache-2.0 license header.
rag/flow/extractor/extractor.py (new file, 59 lines; the standard Apache-2.0 license header at the top is omitted here):

import random

from agent.component.llm import LLMParam, LLM


class ExtractorParam(LLMParam):
    def __init__(self):
        super().__init__()
        self.field_name = ""

    def check(self):
        super().check()
        self.check_empty(self.field_name, "Result Destination")


class Extractor(LLM):
    component_name = "Extractor"

    async def _invoke(self, **kwargs):
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                chunks = args[k]
                chunks_key = k

        if chunks:
            prog = 0
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                prog += 1./len(chunks)
                self.callback(prog, f"{i+1} / {len(chunks)}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
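In isolation, the chunk branch of `_invoke` behaves as sketched below: each chunk's text is formatted into the prompt and the model's answer is written back onto the chunk under the configured `field_name`. The LLM call is stubbed here; the real component goes through `_generate` and its configured model:

```python
def fake_generate(messages):
    # Stand-in for the LLM call; returns a canned "extraction".
    return "ACME Corp"

field_name = "company"   # ExtractorParam.field_name ("Result Destination")
chunks = [{"text": "Invoice issued by ACME Corp on 2025-01-01."},
          {"text": "Payment received from ACME Corp."}]

for ck in chunks:
    messages = [{"role": "system", "content": "Extract the company name."},
                {"role": "user", "content": ck["text"]}]
    ck[field_name] = fake_generate(messages)

print(chunks[0])
# {'text': 'Invoice issued by ACME Corp on 2025-01-01.', 'company': 'ACME Corp'}
```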
rag/flow/extractor/schema.py (new file, 38 lines; the standard Apache-2.0 license header at the top is omitted here):

from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


class ExtractorFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    file: dict | None = Field(default=None)
    chunks: list[dict[str, Any]] | None = Field(default=None)

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: list[str] | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # def to_dict(self, *, exclude_none: bool = True) -> dict:
    #     return self.model_dump(by_alias=True, exclude_none=exclude_none)
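A short usage sketch for the new schema. Pydantic accepts the aliased keys produced upstream (`_created_time`, `json`, ...) and, because of `populate_by_name=True`, the field names themselves; unknown keys are rejected because of `extra="forbid"`:

```python
from rag.flow.extractor.schema import ExtractorFromUpstream

upstream = ExtractorFromUpstream.model_validate({
    "_created_time": 12.5,                  # aliased field
    "name": "splitter_0",                   # hypothetical upstream component name
    "chunks": [{"text": "some chunk text"}],
    "output_format": "json",
})
print(upstream.created_time)        # 12.5
print(upstream.chunks[0]["text"])   # some chunk text

# An unexpected key such as {"foo": 1} would raise a pydantic ValidationError.
```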
@@ -17,15 +17,11 @@ import datetime
 import json
 import logging
 import random
-import time
 from timeit import default_timer as timer
 import trio
 
 from agent.canvas import Graph
-from api.db import PipelineTaskType
 from api.db.services.document_service import DocumentService
-from api.db.services.task_service import has_canceled
-from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
 from rag.utils.redis_conn import REDIS_CONN
 
@@ -34,9 +30,9 @@ class Pipeline(Graph):
         if isinstance(dsl, dict):
             dsl = json.dumps(dsl, ensure_ascii=False)
         super().__init__(dsl, tenant_id, task_id)
+        if doc_id == CANVAS_DEBUG_DOC_ID:
+            doc_id = None
         self._doc_id = doc_id
-        if self._doc_id == "x":
-            self._doc_id = None
         self._flow_id = flow_id
         self._kb_id = None
         if self._doc_id:
@@ -80,7 +76,7 @@ class Pipeline(Graph):
                 }
             ]
             REDIS_CONN.set_obj(log_key, obj, 60 * 30)
-            if self._doc_id:
+            if self._doc_id and self.task_id:
                 percentage = 1.0 / len(self.components.items())
                 msg = ""
                 finished = 0.0
@@ -96,7 +92,7 @@ class Pipeline(Graph):
                     if finished < 0:
                         break
                     finished += o["trace"][-1]["progress"] * percentage
-                DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
+                TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
         except Exception as e:
             logging.exception(e)
 
@@ -113,34 +109,32 @@ class Pipeline(Graph):
             logging.exception(e)
             return []
 
-    def reset(self):
-        super().reset()
+    async def run(self, **kwargs):
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
             REDIS_CONN.set_obj(log_key, [], 60 * 10)
         except Exception as e:
             logging.exception(e)
+        self.error = ""
 
-    async def run(self, **kwargs):
-        st = time.perf_counter()
         if not self.path:
             self.path.append("File")
 
-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-            )
-
-        self.error = ""
-        idx = len(self.path) - 1
-        if idx == 0:
-            cpn_obj = self.get_component_obj(self.path[0])
-            await cpn_obj.invoke(**kwargs)
-            if cpn_obj.error():
-                self.error = "[ERROR]" + cpn_obj.error()
-            else:
-                idx += 1
-                self.path.extend(cpn_obj.get_downstream())
+        cpn_obj = self.get_component_obj(self.path[0])
+        await cpn_obj.invoke(**kwargs)
+        if cpn_obj.error():
+            self.error = "[ERROR]" + cpn_obj.error()
+            self.callback(cpn_obj.component_name, -1, self.error)
+        if self._doc_id:
+            TaskService.update_progress(self.task_id, {
+                "progress": random.randint(0, 5) / 100.0,
+                "progress_msg": "Start the pipeline...",
+                "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
+
+        idx = len(self.path) - 1
+        cpn_obj = self.get_component_obj(self.path[idx])
+        idx += 1
+        self.path.extend(cpn_obj.get_downstream())
 
         while idx < len(self.path) and not self.error:
             last_cpn = self.get_component_obj(self.path[idx - 1])
@@ -152,23 +146,21 @@ class Pipeline(Graph):
 
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(invoke)
 
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
-                self.callback(cpn_obj.component_name, -1, self.error)
+                self.callback(cpn_obj._id, -1, self.error)
                 break
             idx += 1
             self.path.extend(cpn_obj.get_downstream())
 
-        self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
+        self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
 
-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id,
-                {
-                    "progress": 1 if not self.error else -1,
-                    "progress_msg": "Pipeline finished...\n" + self.error,
-                    "process_duration": time.perf_counter() - st,
-                },
-            )
-
-            PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
+        if not self.error:
+            return self.get_component_obj(self.path[-1]).output()
+
+        TaskService.update_progress(self.task_id, {
+            "progress": -1,
+            "progress_msg": f"[ERROR]: {self.error}"})
+
+        return {}
@@ -99,7 +99,7 @@ class Splitter(ProcessBase):
             {
                 "text": RAGFlowPdfParser.remove_tag(c),
                 "image": img,
-                "positions": RAGFlowPdfParser.extract_positions(c),
+                "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
             }
             for c, img in zip(chunks, images)
         ]
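The positions emitted by the splitter are now rebuilt so the page component becomes 1-based. A sketch of the transformation, assuming each position is a tuple whose first element ends with a zero-based page index followed by coordinates (the exact tuple layout comes from `RAGFlowPdfParser` and is not shown in this diff):

```python
# Hypothetical positions as (page_tuple, x0, x1, top, bottom); only the page
# handling matters for this sketch.
positions = [((0,), 10.0, 120.0, 50.0, 80.0),
             ((1,), 15.0, 130.0, 40.0, 70.0)]

# Same reshaping as the diff: take the last element of pos[0], shift it to
# 1-based, and keep the remaining coordinates unchanged.
rebuilt = [[pos[0][-1] + 1, *pos[1:]] for pos in positions]
print(rebuilt)
# [[1, 10.0, 120.0, 50.0, 80.0], [2, 15.0, 130.0, 40.0, 70.0]]
```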
@@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
                 ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
             if ck.get("keywords"):
                 ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
-            ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
-            ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            if ck.get("summary"):
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            else:
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             if i % 100 == 99:
                 self.callback(i * 1.0 / len(chunks) / parts)
 
@@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
         res.append(d)
     return res
 
+
 def tokenize_chunks_with_images(chunks, doc, eng, images):
     res = []
     # wrap up as es documents
@@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
         res.append(d)
     return res
 
+
 def tokenize_table(tbls, doc, eng, batch_size=10):
     res = []
     # add tables
@@ -21,6 +21,7 @@ import sys
 import threading
 import time
 from api.db.services.canvas_service import UserCanvasService
+from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
@@ -49,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
         max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
         logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
 
+
 # SIGUSR2 handler: stop tracemalloc
 def stop_tracemalloc(signum, frame):
     if tracemalloc.is_tracing():
@@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame):
     else:
         logging.info("tracemalloc not running")
 
+
 class TaskCanceledException(Exception):
     def __init__(self, msg):
         self.msg = msg
@@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 
 
 async def run_dataflow(task: dict):
+    task_start_ts = timer()
     dataflow_id = task["dataflow_id"]
-    e, cvs = UserCanvasService.get_by_id(dataflow_id)
-    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
-    pipeline.reset()
-    await pipeline.run(file=task.get("file"))
+    doc_id = task["doc_id"]
+    task_id = task["id"]
+    if task["task_type"] == "dataflow":
+        e, cvs = UserCanvasService.get_by_id(dataflow_id)
+        assert e, "User pipeline not found."
+        dsl = cvs.dsl
+    else:
+        e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
+        assert e, "Pipeline log not found."
+        dsl = pipeline_log.dsl
+    pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
+    chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
+    if doc_id == CANVAS_DEBUG_DOC_ID:
+        return
+
+    if not chunks:
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
+        return
+
+    embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
+    if chunks.get("chunks"):
+        chunks = chunks["chunks"]
+    elif chunks.get("json"):
+        chunks = chunks["json"]
+    elif chunks.get("markdown"):
+        chunks = [{"text": [chunks["markdown"]]}]
+    elif chunks.get("text"):
+        chunks = [{"text": [chunks["text"]]}]
+    elif chunks.get("html"):
+        chunks = [{"text": [chunks["html"]]}]
+
+    keys = [k for o in chunks for k in list(o.keys())]
+    if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
+        set_progress(task_id, prog=0.82, msg="Start to embedding...")
+        e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+        embedding_id = kb.embd_id
+        embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
+
+        @timeout(60)
+        def batch_encode(txts):
+            nonlocal embedding_model
+            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
+
+        vects = np.array([])
+        texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
+        delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
+        prog = 0.8
+        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
+            async with embed_limiter:
+                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
+            if len(vects) == 0:
+                vects = vts
+            else:
+                vects = np.concatenate((vects, vts), axis=0)
+            embedding_token_consumption += c
+            prog += delta
+            set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
+
+        assert len(vects) == len(chunks)
+        for i, ck in enumerate(chunks):
+            v = vects[i].tolist()
+            ck["q_%d_vec" % len(v)] = v
+
+    for ck in chunks:
+        ck["doc_id"] = task["doc_id"]
+        ck["kb_id"] = [str(task["kb_id"])]
+        ck["docnm_kwd"] = task["name"]
+        ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        ck["create_timestamp_flt"] = datetime.now().timestamp()
+        ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
+        if "questions" in ck:
+            del ck["questions"]
+        if "keywords" in ck:
+            del ck["keywords"]
+        if "summary" in ck:
+            del ck["summary"]
+        del ck["text"]
+
+    start_ts = timer()
+    set_progress(task_id, prog=0.82, msg="Start to index...")
+    e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
+    if not e:
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
+        return
+
+    time_cost = timer() - start_ts
+    task_time_cost = timer() - task_start_ts
+    set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
+    logging.info(
+        "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
+    PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
 
 
 @timeout(3600)
@@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
+async def delete_image(kb_id, chunk_id):
+    try:
+        async with minio_limiter:
+            STORAGE_IMPL.delete(kb_id, chunk_id)
+    except Exception:
+        logging.exception(f"Deleting image of chunk {chunk_id} got exception")
+        raise
+
+
+async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
+    for b in range(0, len(chunks), DOC_BULK_SIZE):
+        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        task_canceled = has_canceled(task_id)
+        if task_canceled:
+            progress_callback(-1, msg="Task has been canceled.")
+            return
+        if b % 128 == 0:
+            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
+        if doc_store_result:
+            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
+            progress_callback(-1, msg=error_message)
+            raise Exception(error_message)
+        chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
+        chunk_ids_str = " ".join(chunk_ids)
+        try:
+            TaskService.update_chunk_ids(task_id, chunk_ids_str)
+        except DoesNotExist:
+            logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
+            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+            async with trio.open_nursery() as nursery:
+                for chunk_id in chunk_ids:
+                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
+            progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
+            return
+    return True
+
+
 @timeout(60*60*2, 1)
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
 
-    if task_type == "dataflow" and task.get("doc_id", "") == "x":
+    if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
         await run_dataflow(task)
         return
 
@@ -569,7 +695,7 @@ async def do_handle_task(task):
 
     init_kb(task, vector_size)
 
-    if task_type == "dataflow":
+    if task_type[:len("dataflow")] == "dataflow":
         await run_dataflow(task)
         return
 
@@ -631,41 +757,9 @@ async def do_handle_task(task):
 
     chunk_count = len(set([chunk["id"] for chunk in chunks]))
     start_ts = timer()
-    doc_store_result = ""
-
-    async def delete_image(kb_id, chunk_id):
-        try:
-            async with minio_limiter:
-                STORAGE_IMPL.delete(kb_id, chunk_id)
-        except Exception:
-            logging.exception(
-                "Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
-            raise
-
-    for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
-        task_canceled = has_canceled(task_id)
-        if task_canceled:
-            progress_callback(-1, msg="Task has been canceled.")
-            return
-        if b % 128 == 0:
-            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
-        if doc_store_result:
-            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
-            progress_callback(-1, msg=error_message)
-            raise Exception(error_message)
-        chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
-        chunk_ids_str = " ".join(chunk_ids)
-        try:
-            TaskService.update_chunk_ids(task["id"], chunk_ids_str)
-        except DoesNotExist:
-            logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
-            async with trio.open_nursery() as nursery:
-                for chunk_id in chunk_ids:
-                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
-            progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
-            return
+    e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
+    if not e:
+        return
 
     logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
                                                                                      task_to_page, len(chunks),
@@ -715,7 +809,8 @@ async def handle_task():
         task_document_ids = []
         if task_type in ["graphrag"]:
             task_document_ids = task["doc_ids"]
-        PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+        if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
+            PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
 
     redis_msg.ack()
 