Feat: add extractor component. (#10271)

### What problem does this PR solve?

Adds an Extractor pipeline component that runs an LLM over each upstream chunk (or over the whole input when no chunk list is present) and writes the result into a configurable destination field, along with the supporting upstream schema, task-based progress reporting, and summary-aware tokenization.
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu
2025-09-25 11:34:47 +08:00
committed by GitHub
parent 840b2b5809
commit 1b19d302c5
16 changed files with 379 additions and 127 deletions

View File

@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
        self.set_output("_created_time", time.perf_counter())
        for k, v in kwargs.items():
            self.set_output(k, v)
        #try:
        with trio.fail_after(self._param.timeout):
            await self._invoke(**kwargs)
        self.callback(1, "Done")
        #except Exception as e:
        #    if self.get_exception_default_value():
        #        self.set_exception_default_value()
        #    else:
        #        self.set_output("_ERROR", str(e))
        #        logging.exception(e)
        #    self.callback(-1, str(e))
        try:
            with trio.fail_after(self._param.timeout):
                await self._invoke(**kwargs)
            self.callback(1, "Done")
        except Exception as e:
            if self.get_exception_default_value():
                self.set_exception_default_value()
            else:
                self.set_output("_ERROR", str(e))
                logging.exception(e)
            self.callback(-1, str(e))
        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
        return self.output()
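The uncommented block activates a timeout-plus-fallback flow: bound `_invoke` with `trio.fail_after`, then on any failure either substitute the configured default or surface the error under `_ERROR`. A minimal standalone sketch of that pattern (the `run_component` helper and its arguments are hypothetical names, not the component's real API):

```python
import logging
import time

import trio


async def run_component(invoke, timeout_s, default=None):
    # Sketch of the ProcessBase flow above: run `invoke` under a trio
    # timeout; on failure fall back to `default` or record the error.
    outputs = {"_created_time": time.perf_counter()}
    try:
        with trio.fail_after(timeout_s):   # raises trio.TooSlowError on timeout
            outputs.update(await invoke())
    except Exception as e:
        if default is not None:
            outputs["result"] = default    # stand-in for set_exception_default_value()
        else:
            outputs["_ERROR"] = str(e)
            logging.exception(e)
    outputs["_elapsed_time"] = time.perf_counter() - outputs["_created_time"]
    return outputs


async def main():
    async def slow():
        await trio.sleep(5)                # exceeds the 0.1 s budget, so the default wins
        return {"result": "done"}

    print(await run_component(slow, 0.1, default="fallback"))


trio.run(main)
```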

View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,59 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

from agent.component.llm import LLMParam, LLM


class ExtractorParam(LLMParam):
    def __init__(self):
        super().__init__()
        self.field_name = ""

    def check(self):
        super().check()
        self.check_empty(self.field_name, "Result Destination")


class Extractor(LLM):
    component_name = "Extractor"

    async def _invoke(self, **kwargs):
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                chunks = args[k]
                chunks_key = k

        if chunks:
            prog = 0
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                prog += 1./len(chunks)
                self.callback(prog, f"{i+1} / {len(chunks)}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


class ExtractorFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    file: dict | None = Field(default=None)
    chunks: list[dict[str, Any]] | None = Field(default=None)

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: list[str] | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # def to_dict(self, *, exclude_none: bool = True) -> dict:
    #     return self.model_dump(by_alias=True, exclude_none=exclude_none)
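Because of `populate_by_name=True`, the model accepts either the alias (as emitted by upstream components, e.g. `"json"`) or the Python field name, while `extra="forbid"` rejects anything unexpected. A usage sketch against the model above (the payload values are made up):

```python
# Sample upstream payload; aliases map onto the snake_case fields.
payload = {
    "name": "splitter",
    "_created_time": 12.5,            # alias populates created_time
    "output_format": "json",
    "json": [{"text": "chunk one"}],  # alias populates json_result
}
m = ExtractorFromUpstream.model_validate(payload)
print(m.json_result)  # [{'text': 'chunk one'}]

# populate_by_name=True also accepts the field name itself:
ExtractorFromUpstream(name="splitter", json_result=[{"text": "chunk one"}])

# extra="forbid" rejects unknown keys with a ValidationError:
try:
    ExtractorFromUpstream.model_validate({"name": "x", "unexpected": 1})
except Exception as e:
    print(type(e).__name__)  # ValidationError
```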

View File

@@ -17,15 +17,11 @@ import datetime
import json
import logging
import random
import time
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db import PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN
@@ -34,9 +30,9 @@ class Pipeline(Graph):
        if isinstance(dsl, dict):
            dsl = json.dumps(dsl, ensure_ascii=False)
        super().__init__(dsl, tenant_id, task_id)
        if doc_id == CANVAS_DEBUG_DOC_ID:
            doc_id = None
        self._doc_id = doc_id
        if self._doc_id == "x":
            self._doc_id = None
        self._flow_id = flow_id
        self._kb_id = None
        if self._doc_id:
@@ -80,7 +76,7 @@
                }
            ]
            REDIS_CONN.set_obj(log_key, obj, 60 * 30)
        if self._doc_id:
        if self._doc_id and self.task_id:
            percentage = 1.0 / len(self.components.items())
            msg = ""
            finished = 0.0
@@ -96,7 +92,7 @@
                    if finished < 0:
                        break
                    finished += o["trace"][-1]["progress"] * percentage
            DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
            TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
        except Exception as e:
            logging.exception(e)
@@ -113,34 +109,32 @@
            logging.exception(e)
        return []

    def reset(self):
        super().reset()

    async def run(self, **kwargs):
        log_key = f"{self._flow_id}-{self.task_id}-logs"
        try:
            REDIS_CONN.set_obj(log_key, [], 60 * 10)
        except Exception as e:
            logging.exception(e)

    async def run(self, **kwargs):
        st = time.perf_counter()
        self.error = ""
        if not self.path:
            self.path.append("File")
        if self._doc_id:
            DocumentService.update_by_id(
                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
            )
        self.error = ""
        idx = len(self.path) - 1
        if idx == 0:
            cpn_obj = self.get_component_obj(self.path[0])
            await cpn_obj.invoke(**kwargs)
            if cpn_obj.error():
                self.error = "[ERROR]" + cpn_obj.error()
            else:
                idx += 1
                self.path.extend(cpn_obj.get_downstream())
            self.callback(cpn_obj.component_name, -1, self.error)
        if self._doc_id:
            TaskService.update_progress(self.task_id, {
                "progress": random.randint(0, 5) / 100.0,
                "progress_msg": "Start the pipeline...",
                "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
        idx = len(self.path) - 1
        cpn_obj = self.get_component_obj(self.path[idx])
        idx += 1
        self.path.extend(cpn_obj.get_downstream())
        while idx < len(self.path) and not self.error:
            last_cpn = self.get_component_obj(self.path[idx - 1])
@@ -152,23 +146,21 @@
            async with trio.open_nursery() as nursery:
                nursery.start_soon(invoke)
            if cpn_obj.error():
                self.error = "[ERROR]" + cpn_obj.error()
                self.callback(cpn_obj.component_name, -1, self.error)
                self.callback(cpn_obj._id, -1, self.error)
                break
            idx += 1
            self.path.extend(cpn_obj.get_downstream())
        self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
        self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
        if self._doc_id:
            DocumentService.update_by_id(
                self._doc_id,
                {
                    "progress": 1 if not self.error else -1,
                    "progress_msg": "Pipeline finished...\n" + self.error,
                    "process_duration": time.perf_counter() - st,
                },
            )
            if not self.error:
                return self.get_component_obj(self.path[-1]).output()
            PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
            TaskService.update_progress(self.task_id, {
                "progress": -1,
                "progress_msg": f"[ERROR]: {self.error}"})
        return {}
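The reworked `run` is a straight walk over `self.path`: invoke the component at the cursor, append its downstream ids to the path, and stop at the first error. A toy traversal showing just that shape (the graph dict and its stub components are assumptions, not the real Graph API):

```python
import json

# Hypothetical three-component graph; only the traversal matches
# Pipeline.run above, the components themselves are stubs.
graph = {
    "File": {"downstream": ["Splitter"], "output": {"file": "a.pdf"}},
    "Splitter": {"downstream": ["Tokenizer"], "output": {"chunks": 3}},
    "Tokenizer": {"downstream": [], "output": {"tokens": 120}},
}

path, error = ["File"], ""
idx = 0
while idx < len(path) and not error:
    cpn = graph[path[idx]]          # the real code awaits this in a trio nursery
    idx += 1
    path.extend(cpn["downstream"])  # each component appends its successors

print(path)                                   # ['File', 'Splitter', 'Tokenizer']
print(json.dumps(graph[path[-1]]["output"]))  # shape of the END callback payload
```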

View File

@@ -99,7 +99,7 @@ class Splitter(ProcessBase):
                {
                    "text": RAGFlowPdfParser.remove_tag(c),
                    "image": img,
                    "positions": RAGFlowPdfParser.extract_positions(c),
                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
                }
                for c, img in zip(chunks, images)
            ]
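The new comprehension normalizes each position's page index: `pos[0]` holds the page numbers for the chunk, `pos[0][-1] + 1` keeps only the last one shifted to 1-based, and the remaining coordinates pass through untouched. A quick check with made-up positions (the tuple shape is an assumption inferred from this usage):

```python
# Assumed shape: (page_numbers, left, right, top, bottom), pages 0-based.
positions = [([0], 10.0, 120.0, 30.0, 55.0),
             ([1, 2], 15.0, 110.0, 40.0, 70.0)]

converted = [[pos[0][-1] + 1, *pos[1:]] for pos in positions]
print(converted)
# [[1, 10.0, 120.0, 30.0, 55.0], [3, 15.0, 110.0, 40.0, 70.0]]
```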

View File

@@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
                ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
            if ck.get("keywords"):
                ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
            ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
            ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
            if ck.get("summary"):
                ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
            else:
                ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
            if i % 100 == 99:
                self.callback(i * 1.0 / len(chunks) / parts)
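With this change, a chunk that carries a `summary` field (e.g. one produced by the new Extractor writing into that destination) is indexed on the summary text instead of the raw text. A minimal sketch of the branch with a stand-in tokenizer (`toy_tokenize` is a hypothetical stub, not `rag_tokenizer`):

```python
def toy_tokenize(s):
    # stand-in for rag_tokenizer.tokenize, not the real implementation
    return " ".join(s.lower().split())


def index_fields(ck):
    # prefer the LLM-produced summary when present, else the raw text
    source = ck["summary"] if ck.get("summary") else ck["text"]
    ck["content_ltks"] = toy_tokenize(source)
    return ck


print(index_fields({"text": "Raw TEXT here"}))
print(index_fields({"text": "Raw TEXT here", "summary": "A short summary"}))
```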