mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-30 00:32:30 +08:00
Feat: add extractor component. (#10271)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
for k, v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
#try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
#except Exception as e:
|
||||
# if self.get_exception_default_value():
|
||||
# self.set_exception_default_value()
|
||||
# else:
|
||||
# self.set_output("_ERROR", str(e))
|
||||
# logging.exception(e)
|
||||
# self.callback(-1, str(e))
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self.callback(-1, str(e))
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
|
||||
15
rag/flow/extractor/__init__.py
Normal file
15
rag/flow/extractor/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
59
rag/flow/extractor/extractor.py
Normal file
59
rag/flow/extractor/extractor.py
Normal file
@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
class ExtractorParam(LLMParam):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field_name = ""
|
||||
|
||||
def check(self):
|
||||
super().check()
|
||||
self.check_empty(self.field_name, "Result Destination")
|
||||
|
||||
|
||||
class Extractor(LLM):
|
||||
component_name = "Extractor"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||
inputs = self.get_input_elements()
|
||||
chunks = []
|
||||
chunks_key = ""
|
||||
args = {}
|
||||
for k, v in inputs.items():
|
||||
args[k] = v["value"]
|
||||
if isinstance(args[k], list):
|
||||
chunks = args[k]
|
||||
chunks_key = k
|
||||
|
||||
if chunks:
|
||||
prog = 0
|
||||
for i, ck in enumerate(chunks):
|
||||
args[chunks_key] = ck["text"]
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
ck[self._param.field_name] = self._generate(msg)
|
||||
prog += 1./len(chunks)
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
self.set_output("chunks", chunks)
|
||||
else:
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
|
||||
|
||||
|
||||
38
rag/flow/extractor/schema.py
Normal file
38
rag/flow/extractor/schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExtractorFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: list[str] | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
@ -17,15 +17,11 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
from agent.canvas import Graph
|
||||
from api.db import PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@ -34,9 +30,9 @@ class Pipeline(Graph):
|
||||
if isinstance(dsl, dict):
|
||||
dsl = json.dumps(dsl, ensure_ascii=False)
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
doc_id = None
|
||||
self._doc_id = doc_id
|
||||
if self._doc_id == "x":
|
||||
self._doc_id = None
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
if self._doc_id:
|
||||
@ -80,7 +76,7 @@ class Pipeline(Graph):
|
||||
}
|
||||
]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
|
||||
if self._doc_id:
|
||||
if self._doc_id and self.task_id:
|
||||
percentage = 1.0 / len(self.components.items())
|
||||
msg = ""
|
||||
finished = 0.0
|
||||
@ -96,7 +92,7 @@ class Pipeline(Graph):
|
||||
if finished < 0:
|
||||
break
|
||||
finished += o["trace"][-1]["progress"] * percentage
|
||||
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
|
||||
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -113,34 +109,32 @@ class Pipeline(Graph):
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
|
||||
async def run(self, **kwargs):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60 * 10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.error = ""
|
||||
if not self.path:
|
||||
self.path.append("File")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
)
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
|
||||
if self._doc_id:
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": random.randint(0, 5) / 100.0,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
|
||||
|
||||
idx = len(self.path) - 1
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx - 1])
|
||||
@ -152,23 +146,21 @@ class Pipeline(Graph):
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
self.callback(cpn_obj._id, -1, self.error)
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id,
|
||||
{
|
||||
"progress": 1 if not self.error else -1,
|
||||
"progress_msg": "Pipeline finished...\n" + self.error,
|
||||
"process_duration": time.perf_counter() - st,
|
||||
},
|
||||
)
|
||||
if not self.error:
|
||||
return self.get_component_obj(self.path[-1]).output()
|
||||
|
||||
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": -1,
|
||||
"progress_msg": f"[ERROR]: {self.error}"})
|
||||
|
||||
return {}
|
||||
|
||||
@ -99,7 +99,7 @@ class Splitter(ProcessBase):
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
}
|
||||
for c, img in zip(chunks, images)
|
||||
]
|
||||
|
||||
@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i * 1.0 / len(chunks) / parts)
|
||||
|
||||
|
||||
@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
res = []
|
||||
# add tables
|
||||
|
||||
@ -21,6 +21,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.utils.api_utils import timeout
|
||||
from api.utils.base64_image import image2id
|
||||
@ -49,7 +50,7 @@ from peewee import DoesNotExist
|
||||
from api.db import LLMType, ParserType, PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService, has_canceled
|
||||
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api import settings
|
||||
from api.versions import get_ragflow_version
|
||||
@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
if tracemalloc.is_tracing():
|
||||
@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame):
|
||||
else:
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
|
||||
async def run_dataflow(task: dict):
|
||||
task_start_ts = timer()
|
||||
dataflow_id = task["dataflow_id"]
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
|
||||
pipeline.reset()
|
||||
await pipeline.run(file=task.get("file"))
|
||||
doc_id = task["doc_id"]
|
||||
task_id = task["id"]
|
||||
if task["task_type"] == "dataflow":
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
assert e, "User pipeline not found."
|
||||
dsl = cvs.dsl
|
||||
else:
|
||||
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
|
||||
assert e, "Pipeline log not found."
|
||||
dsl = pipeline_log.dsl
|
||||
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
|
||||
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
return
|
||||
|
||||
if not chunks:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
|
||||
if chunks.get("chunks"):
|
||||
chunks = chunks["chunks"]
|
||||
elif chunks.get("json"):
|
||||
chunks = chunks["json"]
|
||||
elif chunks.get("markdown"):
|
||||
chunks = [{"text": [chunks["markdown"]]}]
|
||||
elif chunks.get("text"):
|
||||
chunks = [{"text": [chunks["text"]]}]
|
||||
elif chunks.get("html"):
|
||||
chunks = [{"text": [chunks["html"]]}]
|
||||
|
||||
keys = [k for o in chunks for k in list(o.keys())]
|
||||
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
|
||||
set_progress(task_id, prog=0.82, msg="Start to embedding...")
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
|
||||
vects = np.array([])
|
||||
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
|
||||
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
vects = np.concatenate((vects, vts), axis=0)
|
||||
embedding_token_consumption += c
|
||||
prog += delta
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
v = vects[i].tolist()
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
|
||||
for ck in chunks:
|
||||
ck["doc_id"] = task["doc_id"]
|
||||
ck["kb_id"] = [str(task["kb_id"])]
|
||||
ck["docnm_kwd"] = task["name"]
|
||||
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||
if "questions" in ck:
|
||||
del ck["questions"]
|
||||
if "keywords" in ck:
|
||||
del ck["keywords"]
|
||||
if "summary" in ck:
|
||||
del ck["summary"]
|
||||
del ck["text"]
|
||||
|
||||
start_ts = timer()
|
||||
set_progress(task_id, prog=0.82, msg="Start to index...")
|
||||
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
||||
if not e:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
task_time_cost = timer() - task_start_ts
|
||||
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
logging.info(
|
||||
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
return res, tk_count
|
||||
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
|
||||
raise
|
||||
|
||||
|
||||
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
|
||||
return
|
||||
return True
|
||||
|
||||
|
||||
@timeout(60*60*2, 1)
|
||||
async def do_handle_task(task):
|
||||
task_type = task.get("task_type", "")
|
||||
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == "x":
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -569,7 +695,7 @@ async def do_handle_task(task):
|
||||
|
||||
init_kb(task, vector_size)
|
||||
|
||||
if task_type == "dataflow":
|
||||
if task_type[:len("dataflow")] == "dataflow":
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -631,41 +757,9 @@ async def do_handle_task(task):
|
||||
|
||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||
start_ts = timer()
|
||||
doc_store_result = ""
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
|
||||
raise
|
||||
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
|
||||
return
|
||||
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
|
||||
if not e:
|
||||
return
|
||||
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
@ -715,7 +809,8 @@ async def handle_task():
|
||||
task_document_ids = []
|
||||
if task_type in ["graphrag"]:
|
||||
task_document_ids = task["doc_ids"]
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
|
||||
redis_msg.ack()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user