Feat: add extractor component. (#10271)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-09-25 11:34:47 +08:00
committed by GitHub
parent 840b2b5809
commit 1b19d302c5
16 changed files with 379 additions and 127 deletions

View File

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
self.set_output("_created_time", time.perf_counter())
for k, v in kwargs.items():
self.set_output(k, v)
#try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
#except Exception as e:
# if self.get_exception_default_value():
# self.set_exception_default_value()
# else:
# self.set_output("_ERROR", str(e))
# logging.exception(e)
# self.callback(-1, str(e))
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self.callback(-1, str(e))
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()

View File

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,59 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from agent.component.llm import LLMParam, LLM
class ExtractorParam(LLMParam):
def __init__(self):
super().__init__()
self.field_name = ""
def check(self):
super().check()
self.check_empty(self.field_name, "Result Destination")
class Extractor(LLM):
component_name = "Extractor"
async def _invoke(self, **kwargs):
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
inputs = self.get_input_elements()
chunks = []
chunks_key = ""
args = {}
for k, v in inputs.items():
args[k] = v["value"]
if isinstance(args[k], list):
chunks = args[k]
chunks_key = k
if chunks:
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
ck[self._param.field_name] = self._generate(msg)
prog += 1./len(chunks)
self.callback(prog, f"{i+1} / {len(chunks)}")
self.set_output("chunks", chunks)
else:
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExtractorFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: list[str] | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@ -17,15 +17,11 @@ import datetime
import json
import logging
import random
import time
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db import PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN
@ -34,9 +30,9 @@ class Pipeline(Graph):
if isinstance(dsl, dict):
dsl = json.dumps(dsl, ensure_ascii=False)
super().__init__(dsl, tenant_id, task_id)
if doc_id == CANVAS_DEBUG_DOC_ID:
doc_id = None
self._doc_id = doc_id
if self._doc_id == "x":
self._doc_id = None
self._flow_id = flow_id
self._kb_id = None
if self._doc_id:
@ -80,7 +76,7 @@ class Pipeline(Graph):
}
]
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id:
if self._doc_id and self.task_id:
percentage = 1.0 / len(self.components.items())
msg = ""
finished = 0.0
@ -96,7 +92,7 @@ class Pipeline(Graph):
if finished < 0:
break
finished += o["trace"][-1]["progress"] * percentage
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
except Exception as e:
logging.exception(e)
@ -113,34 +109,32 @@ class Pipeline(Graph):
logging.exception(e)
return []
def reset(self):
super().reset()
async def run(self, **kwargs):
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60 * 10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
self.error = ""
if not self.path:
self.path.append("File")
if self._doc_id:
DocumentService.update_by_id(
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
)
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
else:
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback(cpn_obj.component_name, -1, self.error)
if self._doc_id:
TaskService.update_progress(self.task_id, {
"progress": random.randint(0, 5) / 100.0,
"progress_msg": "Start the pipeline...",
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
idx = len(self.path) - 1
cpn_obj = self.get_component_obj(self.path[idx])
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx - 1])
@ -152,23 +146,21 @@ class Pipeline(Graph):
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
self.callback(cpn_obj.component_name, -1, self.error)
self.callback(cpn_obj._id, -1, self.error)
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
if self._doc_id:
DocumentService.update_by_id(
self._doc_id,
{
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st,
},
)
if not self.error:
return self.get_component_obj(self.path[-1]).output()
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
TaskService.update_progress(self.task_id, {
"progress": -1,
"progress_msg": f"[ERROR]: {self.error}"})
return {}

View File

@ -99,7 +99,7 @@ class Splitter(ProcessBase):
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c),
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
}
for c, img in zip(chunks, images)
]

View File

@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if ck.get("summary"):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
else:
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i * 1.0 / len(chunks) / parts)

View File

@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res.append(d)
return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
res = []
# wrap up as es documents
@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
res.append(d)
return res
def tokenize_table(tbls, doc, eng, batch_size=10):
res = []
# add tables

View File

@ -21,6 +21,7 @@ import sys
import threading
import time
from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
@ -49,7 +50,7 @@ from peewee import DoesNotExist
from api.db import LLMType, ParserType, PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
@ -146,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
@ -154,6 +156,7 @@ def stop_tracemalloc(signum, frame):
else:
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
self.msg = msg
@ -471,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
async def run_dataflow(task: dict):
task_start_ts = timer()
dataflow_id = task["dataflow_id"]
e, cvs = UserCanvasService.get_by_id(dataflow_id)
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
pipeline.reset()
await pipeline.run(file=task.get("file"))
doc_id = task["doc_id"]
task_id = task["id"]
if task["task_type"] == "dataflow":
e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found."
dsl = cvs.dsl
else:
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
assert e, "Pipeline log not found."
dsl = pipeline_log.dsl
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID:
return
if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
if chunks.get("chunks"):
chunks = chunks["chunks"]
elif chunks.get("json"):
chunks = chunks["json"]
elif chunks.get("markdown"):
chunks = [{"text": [chunks["markdown"]]}]
elif chunks.get("text"):
chunks = [{"text": [chunks["text"]]}]
elif chunks.get("html"):
chunks = [{"text": [chunks["html"]]}]
keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
set_progress(task_id, prog=0.82, msg="Start to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
for ck in chunks:
ck["doc_id"] = task["doc_id"]
ck["kb_id"] = [str(task["kb_id"])]
ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
del ck["questions"]
if "keywords" in ck:
del ck["keywords"]
if "summary" in ck:
del ck["summary"]
del ck["text"]
start_ts = timer()
set_progress(task_id, prog=0.82, msg="Start to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
@timeout(3600)
@ -520,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return True
@timeout(60*60*2, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == "x":
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return
@ -569,7 +695,7 @@ async def do_handle_task(task):
init_kb(task, vector_size)
if task_type == "dataflow":
if task_type[:len("dataflow")] == "dataflow":
await run_dataflow(task)
return
@ -631,41 +757,9 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = ""
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
raise
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
return
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
if not e:
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
@ -715,7 +809,8 @@ async def handle_task():
task_document_ids = []
if task_type in ["graphrag"]:
task_document_ids = task["doc_ids"]
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack()