Feat: add extractor component. (#10271)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu
Date:   2025-09-25 11:34:47 +08:00 (committed by GitHub)
Parent: 840b2b5809
Commit: 1b19d302c5
16 changed files with 379 additions and 127 deletions


@@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import queue_dataflow
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
@@ -174,6 +175,25 @@ def run():
     return resp
 
 
+@manager.route('/rerun', methods=['POST'])  # noqa: F821
+@validate_request("id", "dsl", "component_id")
+@login_required
+def rerun():
+    req = request.json
+    doc = PipelineOperationLogService.get_documents_info(req["id"])
+    if not doc:
+        return get_data_error_result(message="Document not found.")
+    doc = doc[0]
+    if 0 < doc["progress"] < 1:
+        return get_data_error_result(message=f"`{doc['name']}` is processing...")
+    dsl = req["dsl"]
+    dsl["path"] = [req["component_id"]]
+    PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
+    queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
+    return get_json_result(data=True)
+
+
 @manager.route('/cancel/<task_id>', methods=['PUT'])  # noqa: F821
 @login_required
 def cancel(task_id):
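
A quick sketch of how the new route might be exercised, for reviewers. The base URL, mount point, and auth header are assumptions about the deployment, not something this PR defines:

```python
# Hypothetical smoke test for the new /rerun route; the base URL and
# auth scheme below are assumptions -- adapt to your deployment.
import requests

BASE = "http://localhost:9380/v1/canvas"  # assumed mount point
session = requests.Session()
session.headers["Authorization"] = "Bearer <api-key>"  # assumed auth

resp = session.post(f"{BASE}/rerun", json={
    "id": "<pipeline-log-id>",      # checked by @validate_request
    "dsl": {"components": {}},      # canvas DSL to re-queue
    "component_id": "Extractor:0",  # placeholder: component to restart from
})
print(resp.json())  # {"code": 0, "data": true} on success
```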

View File

@@ -121,12 +121,20 @@ class DocumentService(CommonService):
                           orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .where(
+                    (cls.model.kb_id == kb_id),
+                    (fn.LOWER(cls.model.name).contains(keywords.lower()))
+                )
         else:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .where(cls.model.kb_id == kb_id)
         if run_status:
             docs = docs.where(cls.model.run.in_(run_status))
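
One detail worth flagging: the join to UserCanvas is LEFT_OUTER, so documents whose pipeline_id is unset still come back (with a NULL title); an inner join would silently drop them. A minimal peewee sketch of the pattern, using stand-in models rather than the project's schema:

```python
# Stand-in models demonstrating why the join type matters; these are
# not RAGFlow's real tables.
from peewee import JOIN, CharField, ForeignKeyField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")

class Canvas(Model):
    title = CharField()
    class Meta:
        database = db

class Doc(Model):
    name = CharField()
    pipeline = ForeignKeyField(Canvas, null=True)  # may be unset
    class Meta:
        database = db

db.create_tables([Canvas, Doc])
c = Canvas.create(title="ingest-v2")
Doc.create(name="a.pdf", pipeline=c)
Doc.create(name="b.pdf", pipeline=None)  # no pipeline assigned

# LEFT_OUTER keeps b.pdf with title=None; JOIN.INNER would drop it.
rows = Doc.select(Doc.name, Canvas.title).join(Canvas, JOIN.LEFT_OUTER).dicts()
print(list(rows))
# [{'name': 'a.pdf', 'title': 'ingest-v2'}, {'name': 'b.pdf', 'title': None}]
```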


@@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
+            cls.model.pipeline_id,
             cls.model.parser_config,
             cls.model.pagerank,
             cls.model.create_time,


@@ -14,12 +14,13 @@
 # limitations under the License.
 #
 import json
+import logging
 from datetime import datetime
 
 from peewee import fn
 
 from api.db import VALID_PIPELINE_TASK_TYPES
-from api.db.db_models import DB, PipelineOperationLog
+from api.db.db_models import DB, PipelineOperationLog, Document
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
@@ -84,22 +85,20 @@ class PipelineOperationLogService(CommonService):
     def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
         from rag.flow.pipeline import Pipeline
 
         tenant_id = ""
-        title = ""
-        avatar = ""
-        dsl = ""
-        operation_status = ""
         referred_document_id = document_id
         if referred_document_id == "x" and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
         DocumentService.update_progress_immediately([document.to_dict()])
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
+
         if document.progress not in [1, -1]:
             return
         operation_status = document.run
@@ -189,6 +188,20 @@
         return list(logs.dicts()), count
 
+    @classmethod
+    @DB.connection_context()
+    def get_documents_info(cls, id):
+        fields = [
+            Document.id,
+            Document.name,
+            Document.progress
+        ]
+        return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
+            cls.model.id == id,
+            Document.progress > 0,
+            Document.progress < 1
+        ).dicts()
+
     @classmethod
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
@@ -208,3 +221,4 @@ class PipelineOperationLogService(CommonService):
         logs = logs.paginate(page_number, items_per_page)
 
         return list(logs.dicts()), count
+
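
Usage sketch for the new helper (the id is a placeholder). Note it only matches while the backing document's progress is strictly between 0 and 1, i.e. mid-parse:

```python
# Placeholder log id; the query yields dicts with the document's
# id/name/progress, and is empty unless 0 < progress < 1.
rows = list(PipelineOperationLogService.get_documents_info("<log-id>"))
if rows:
    doc = rows[0]
    print(f"{doc['name']} is at {doc['progress']:.0%}")
```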


@@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search
 
+CANVAS_DEBUG_DOC_ID = "dataflow_x"
 
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
@@ -85,7 +86,7 @@
         Returns None if task is not found or has exceeded retry limit.
         """
         doc_id = cls.model.doc_id
-        if doc_id == "x" and doc_ids:
+        if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
             doc_id = doc_ids[0]
 
         fields = [
@@ -476,14 +477,14 @@
     return False
 
 
-def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
     task = dict(
         id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
-        task_type="dataflow",
+        task_type="dataflow" if not rerun else "dataflow_rerun",
         priority=priority,
     )
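
On the consumer side a worker can now branch on the task type; a schematic dispatcher (the function names here are invented for illustration, not RAGFlow's executor):

```python
# Schematic dispatch on the new task type; run_dataflow is a stand-in.
def run_dataflow(task: dict, resume: bool) -> None:
    mode = "resuming from dsl['path']" if resume else "running from the start"
    print(f"task {task['id']}: {mode}")

def dispatch(task: dict) -> None:
    if task["task_type"] == "dataflow_rerun":
        run_dataflow(task, resume=True)   # re-entry queued by /rerun
    elif task["task_type"] == "dataflow":
        run_dataflow(task, resume=False)
    else:
        raise ValueError(f"unknown task type: {task['task_type']}")

dispatch({"id": "t-1", "task_type": "dataflow_rerun"})
```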


@@ -1,4 +1,5 @@
 import base64
+import logging
 from functools import partial
 from io import BytesIO
@@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
 test_image = base64.b64decode(test_image_base64)
 
 
-async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
+async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
     import logging
     from io import BytesIO
     import trio
@@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
     if len(arr) != 2:
         return
     bkt, nm = image_id.split("-")
-    blob = storage_get_func(bucket=bkt, filename=nm)
-    if not blob:
-        return
-    return Image.open(BytesIO(blob))
+    try:
+        blob = storage_get_func(bucket=bkt, filename=nm)
+        if not blob:
+            return
+        return Image.open(BytesIO(blob))
+    except Exception as e:
+        logging.exception(e)
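
The IMAGETEMPS to imagetemps rename is consistent with S3/MinIO bucket naming rules, which require lowercase names. A quick check of the constraint (this validator is a simplification of the full S3 spec):

```python
# Simplified S3/MinIO bucket-name check: 3-63 chars, lowercase letters,
# digits, dots and hyphens, starting and ending with a letter or digit.
import re

def is_valid_bucket(name: str) -> bool:
    return bool(re.fullmatch(r"[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]", name))

print(is_valid_bucket("IMAGETEMPS"))  # False -- the old constant
print(is_valid_bucket("imagetemps"))  # True  -- the new one
```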