Feat: add extractor component. (#10271)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-09-25 11:34:47 +08:00
committed by GitHub
parent 840b2b5809
commit 1b19d302c5
16 changed files with 379 additions and 127 deletions

View File

@ -17,15 +17,11 @@ import datetime
import json
import logging
import random
import time
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db import PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN
@ -34,9 +30,9 @@ class Pipeline(Graph):
if isinstance(dsl, dict):
dsl = json.dumps(dsl, ensure_ascii=False)
super().__init__(dsl, tenant_id, task_id)
if doc_id == CANVAS_DEBUG_DOC_ID:
doc_id = None
self._doc_id = doc_id
if self._doc_id == "x":
self._doc_id = None
self._flow_id = flow_id
self._kb_id = None
if self._doc_id:
@ -80,7 +76,7 @@ class Pipeline(Graph):
}
]
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id:
if self._doc_id and self.task_id:
percentage = 1.0 / len(self.components.items())
msg = ""
finished = 0.0
@ -96,7 +92,7 @@ class Pipeline(Graph):
if finished < 0:
break
finished += o["trace"][-1]["progress"] * percentage
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
except Exception as e:
logging.exception(e)
@ -113,34 +109,32 @@ class Pipeline(Graph):
logging.exception(e)
return []
def reset(self):
super().reset()
async def run(self, **kwargs):
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60 * 10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
self.error = ""
if not self.path:
self.path.append("File")
if self._doc_id:
DocumentService.update_by_id(
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
)
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
else:
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback(cpn_obj.component_name, -1, self.error)
if self._doc_id:
TaskService.update_progress(self.task_id, {
"progress": random.randint(0, 5) / 100.0,
"progress_msg": "Start the pipeline...",
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
idx = len(self.path) - 1
cpn_obj = self.get_component_obj(self.path[idx])
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx - 1])
@ -152,23 +146,21 @@ class Pipeline(Graph):
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
self.callback(cpn_obj.component_name, -1, self.error)
self.callback(cpn_obj._id, -1, self.error)
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
if self._doc_id:
DocumentService.update_by_id(
self._doc_id,
{
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st,
},
)
if not self.error:
return self.get_component_obj(self.path[-1]).output()
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
TaskService.update_progress(self.task_id, {
"progress": -1,
"progress_msg": f"[ERROR]: {self.error}"})
return {}