Compare commits

...

2 Commits

1b19d302c5  Feat: add extractor component. (#10271)  2025-09-25 11:34:47 +08:00

### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

840b2b5809  Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)  2025-09-25 09:35:50 +08:00

### What problem does this PR solve?

Add foundational support for GraphRAG dataset pipeline logs

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
21 changed files with 843 additions and 158 deletions

View File

@@ -113,6 +113,15 @@ class LLM(ComponentBase):
     def add2system_prompt(self, txt):
         self._param.sys_prompt += txt

+    def _sys_prompt_and_msg(self, msg, args):
+        for p in self._param.prompts:
+            if msg and msg[-1]["role"] == p["role"]:
+                continue
+            p = deepcopy(p)
+            p["content"] = self.string_format(p["content"], args)
+            msg.append(p)
+        return msg, self.string_format(self._param.sys_prompt, args)
+
     def _prepare_prompt_variables(self):
         if self._param.visual_files_var:
             self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
@@ -128,7 +137,6 @@ class LLM(ComponentBase):
         args = {}
         vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
-        sys_prompt = self._param.sys_prompt
         for k, o in vars.items():
             args[k] = o["value"]
             if not isinstance(args[k], str):
@@ -138,16 +146,8 @@ class LLM(ComponentBase):
                 args[k] = str(args[k])
             self.set_input_value(k, args[k])

-        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
-        for p in self._param.prompts:
-            if msg and msg[-1]["role"] == p["role"]:
-                continue
-            msg.append(deepcopy(p))
-        sys_prompt = self.string_format(sys_prompt, args)
+        msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
         user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
-        for m in msg:
-            m["content"] = self.string_format(m["content"], args)
         if self._param.cite and self._canvas.get_reference()["chunks"]:
             sys_prompt += citation_prompt(user_defined_prompt)
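For reference, a minimal standalone sketch of what the extracted `_sys_prompt_and_msg` helper does, using plain dicts and assuming `string_format` behaves like `str.format`; the function and variable names below are illustrative, not RAGFlow code.

from copy import deepcopy

def sys_prompt_and_msg(prompts, sys_prompt, msg, args):
    for p in prompts:
        if msg and msg[-1]["role"] == p["role"]:
            continue  # skip a configured prompt whose role would repeat the last message
        p = deepcopy(p)
        p["content"] = p["content"].format(**args)  # stand-in for string_format
        msg.append(p)
    return msg, sys_prompt.format(**args)

msg, sys_prompt = sys_prompt_and_msg(
    prompts=[{"role": "user", "content": "Question: {question}"}],
    sys_prompt="You answer questions about {topic}.",
    msg=[],
    args={"question": "What changed?", "topic": "pipelines"},
)
print(sys_prompt)  # You answer questions about pipelines.
print(msg)         # [{'role': 'user', 'content': 'Question: What changed?'}]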

View File

@@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import queue_dataflow
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
@@ -174,6 +175,25 @@ def run():
     return resp


+@manager.route('/rerun', methods=['POST']) # noqa: F821
+@validate_request("id", "dsl", "component_id")
+@login_required
+def rerun():
+    req = request.json
+    doc = PipelineOperationLogService.get_documents_info(req["id"])
+    if not doc:
+        return get_data_error_result(message="Document not found.")
+    doc = doc[0]
+    if 0 < doc["progress"] < 1:
+        return get_data_error_result(message=f"`{doc['name']}` is processing...")
+    dsl = req["dsl"]
+    dsl["path"] = [req["component_id"]]
+    PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
+    queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
+    return get_json_result(data=True)
+
+
 @manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
 @login_required
 def cancel(task_id):
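A hypothetical client call for the new rerun endpoint. The URL prefix, port, and auth header are assumptions; only the JSON body keys ("id", "dsl", "component_id") come from the route's validate_request decorator above.

import requests

resp = requests.post(
    "http://localhost:9380/v1/canvas/rerun",        # prefix and port are assumptions
    headers={"Authorization": "Bearer <api-key>"},  # auth mechanism assumed
    json={
        "id": "<pipeline_log_id>",               # PipelineOperationLog id
        "dsl": {"components": {}, "path": []},   # DSL to rerun; "path" is reset server-side
        "component_id": "Splitter:0",            # component to restart from (illustrative id)
    },
    timeout=30,
)
print(resp.json())  # {"data": true, ...} on success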

View File

@@ -14,17 +14,19 @@
 # limitations under the License.
 #
 import json
+import logging

 from flask import request
 from flask_login import login_required, current_user

 from api.db.services import duplicate_name
-from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import TaskService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
+from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
 from api.utils import get_uuid
 from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -435,18 +437,60 @@ def list_pipeline_logs():
     suffix = req.get("suffix", [])
     try:
-        docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
+        logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
         if create_time_from or create_time_to:
             filtered_docs = []
-            for doc in docs:
+            for doc in logs:
                 doc_create_time = doc.get("create_time", 0)
                 if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
                     filtered_docs.append(doc)
-            docs = filtered_docs
+            logs = filtered_docs
-        return get_json_result(data={"total": tol, "docs": docs})
+        return get_json_result(data={"total": tol, "logs": logs})
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
+@login_required
+def list_pipeline_dataset_logs():
+    kb_id = request.args.get("kb_id")
+    if not kb_id:
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+    page_number = int(request.args.get("page", 0))
+    items_per_page = int(request.args.get("page_size", 0))
+    orderby = request.args.get("orderby", "create_time")
+    if request.args.get("desc", "true").lower() == "false":
+        desc = False
+    else:
+        desc = True
+    create_time_from = int(request.args.get("create_time_from", 0))
+    create_time_to = int(request.args.get("create_time_to", 0))
+
+    req = request.get_json()
+    operation_status = req.get("operation_status", [])
+    if operation_status:
+        invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
+        if invalid_status:
+            return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
+
+    try:
+        logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
+        if create_time_from or create_time_to:
+            filtered_docs = []
+            for doc in logs:
+                doc_create_time = doc.get("create_time", 0)
+                if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
+                    filtered_docs.append(doc)
+            logs = filtered_docs
+        return get_json_result(data={"total": tol, "logs": logs})
     except Exception as e:
         return server_error_response(e)
@@ -478,3 +522,68 @@ def pipeline_log_detail():
         return get_data_error_result(message="Invalid pipeline log ID")

     return get_json_result(data=log.to_dict())
+
+
+@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
+@login_required
+def run_graphrag():
+    req = request.json
+    kb_id = req.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+    doc_ids = req.get("doc_ids", [])
+    if not doc_ids:
+        return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    task_id = kb.graphrag_task_id
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
+    if task and task.progress not in [-1, 1]:
+        return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
+
+    document_ids = set()
+    sample_document = {}
+    for doc_id in doc_ids:
+        ok, document = DocumentService.get_by_id(doc_id)
+        if ok:
+            document_ids.add(document.id)
+            if not sample_document:
+                sample_document = document.to_dict()
+
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
+    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
+        logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
+
+    return get_json_result(data={"graphrag_task_id": task_id})
+
+
+@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
+@login_required
+def trace_graphrag():
+    kb_id = request.args.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    task_id = kb.graphrag_task_id
+    if not task_id:
+        return get_error_data_result(message="GraphRAG Task ID Not Found")
+
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
+
+    return get_json_result(data=task.to_dict())
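A hypothetical client flow for the two new endpoints. The URL prefix, port, and auth header are assumptions; the payload keys (kb_id, doc_ids) and the returned graphrag_task_id come from the routes above, assuming the standard get_json_result envelope with a "data" field.

import requests

BASE = "http://localhost:9380/v1/kb"             # prefix is an assumption
HEADERS = {"Authorization": "Bearer <api-key>"}  # auth mechanism assumed

# Kick off a KB-level GraphRAG run over selected documents.
run = requests.post(f"{BASE}/run_graphrag", headers=HEADERS,
                    json={"kb_id": "<kb_id>", "doc_ids": ["<doc_id_1>", "<doc_id_2>"]})
task_id = run.json()["data"]["graphrag_task_id"]

# Poll the task recorded on the knowledge base (kb.graphrag_task_id).
trace = requests.get(f"{BASE}/trace_graphrag", headers=HEADERS, params={"kb_id": "<kb_id>"})
print(task_id, trace.json()["data"].get("progress"))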

View File

@@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}

 class PipelineTaskType(StrEnum):
     PARSE = "Parse"
-    DOWNLOAD = "DOWNLOAD"
+    DOWNLOAD = "Download"
+    RAPTOR = "RAPTOR"
+    GRAPH_RAG = "GraphRAG"

-VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
+VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}

 KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
     pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
     parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
     pagerank = IntegerField(default=0, index=False)
+    graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

     def __str__(self):
@@ -1065,11 +1068,15 @@ def migrate_db():
     except Exception:
         pass
     try:
-        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
+        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
     except Exception:
         pass
     try:
-        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
+        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
     except Exception:
         pass
     logging.disable(logging.NOTSET)

View File

@@ -121,12 +121,20 @@ class DocumentService(CommonService):
                       orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .where(
+                    (cls.model.kb_id == kb_id),
+                    (fn.LOWER(cls.model.name).contains(keywords.lower()))
+                )
         else:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+            docs = cls.model.select(*[*fields, UserCanvas.title])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .where(cls.model.kb_id == kb_id)
         if run_status:
             docs = docs.where(cls.model.run.in_(run_status))
@@ -507,6 +515,9 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
+        """
+        highly rely on the strict deduplication guarantee from Document
+        """
         fields = [cls.model.id]
         doc_id = cls.model.select(*fields) \
             .where(cls.model.name == doc_name)
@@ -656,6 +667,7 @@ class DocumentService(CommonService):
                     queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 else:
+                    prg = 1
                     status = TaskStatus.DONE.value
             msg = "\n".join(sorted(msg))
@@ -741,7 +753,11 @@ class DocumentService(CommonService):
         "cancelled": int(cancelled),
     }

-def queue_raptor_o_graphrag_tasks(doc, ty, priority):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
+    """
+    You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
+    Optionally, specify a list of doc_ids to determine which documents participate in the task.
+    """
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -751,7 +767,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
         nonlocal doc
         return {
             "id": get_uuid(),
-            "doc_id": doc["id"],
+            "doc_id": fake_doc_id if fake_doc_id else doc["id"],
             "from_page": 100000000,
             "to_page": 100000000,
             "task_type": ty,
@@ -764,7 +780,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
+    if ty == "graphrag":
+        task["doc_ids"] = doc_ids
     assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
+    return task["id"]


 def get_queue_length(priority):
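For orientation, an illustrative shape of the message queued for a KB-level GraphRAG task, based only on the fields visible in the hunks above; other fields coming from the chunking config may also be present, and the values here are placeholders.

graphrag_task_message = {
    "id": "3f2c...",                 # get_uuid()
    "doc_id": "x",                   # fake_doc_id sentinel instead of a real document id
    "from_page": 100000000,
    "to_page": 100000000,
    "task_type": "graphrag",
    "digest": "9a1b...",             # xxhash over the chunking config plus the task type
    "doc_ids": ["doc_1", "doc_2"],   # attached only when ty == "graphrag"
}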

View File

@@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
+            cls.model.pipeline_id,
             cls.model.parser_config,
             cls.model.pagerank,
             cls.model.create_time,

View File

@@ -14,12 +14,13 @@
 # limitations under the License.
 #
 import json
+import logging
 from datetime import datetime

 from peewee import fn

 from api.db import VALID_PIPELINE_TASK_TYPES
-from api.db.db_models import DB, PipelineOperationLog
+from api.db.db_models import DB, PipelineOperationLog, Document
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
@@ -31,7 +32,7 @@ class PipelineOperationLogService(CommonService):
     model = PipelineOperationLog

     @classmethod
-    def get_cls_model_fields(cls):
+    def get_file_logs_fields(cls):
         return [
             cls.model.id,
             cls.model.document_id,
@@ -59,24 +60,47 @@ class PipelineOperationLogService(CommonService):
             cls.model.update_date,
         ]
@classmethod
def get_dataset_logs_fields(cls):
return [
cls.model.id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
     @classmethod
     @DB.connection_context()
-    def create(cls, document_id, pipeline_id, task_type):
+    def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
         from rag.flow.pipeline import Pipeline

-        tenant_id = ""
-        title = ""
-        avatar = ""
         dsl = ""
-        operation_status = ""
-        ok, document = DocumentService.get_by_id(document_id)
+        referred_document_id = document_id
+        if referred_document_id == "x" and fake_document_ids:
+            referred_document_id = fake_document_ids[0]
+        ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document {document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return

         DocumentService.update_progress_immediately([document.to_dict()])
-        ok, document = DocumentService.get_by_id(document_id)
+        ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document {document_id} not found")
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
+        if document.progress not in [1, -1]:
+            return

         operation_status = document.run
         if pipeline_id:
@@ -84,7 +108,7 @@ class PipelineOperationLogService(CommonService):
             if not ok:
                 raise RuntimeError(f"Pipeline {pipeline_id} not found")
-            pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
+            pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
             tenant_id = user_pipeline.user_id
             title = user_pipeline.title
@@ -93,7 +117,7 @@ class PipelineOperationLogService(CommonService):
         else:
             ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
             if not ok:
-                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
+                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
             tenant_id = kb_info.tenant_id
             title = document.name
@@ -104,7 +128,7 @@ class PipelineOperationLogService(CommonService):
         log = dict(
             id=get_uuid(),
-            document_id=document_id,
+            document_id=document_id,  # "x" or real document_id
             tenant_id=tenant_id,
             kb_id=document.kb_id,
             pipeline_id=pipeline_id,
@@ -132,18 +156,20 @@ class PipelineOperationLogService(CommonService):
     @classmethod
     @DB.connection_context()
-    def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
-        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
+    def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
+        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)

     @classmethod
     @DB.connection_context()
-    def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
-        fields = cls.get_cls_model_fields()
+    def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
+        fields = cls.get_file_logs_fields()
         if keywords:
             logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
         else:
             logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
+        logs = logs.where(cls.model.document_id != "x")
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
         if types:
@@ -161,3 +187,38 @@ class PipelineOperationLogService(CommonService):
             logs = logs.paginate(page_number, items_per_page)
         return list(logs.dicts()), count
@classmethod
@DB.connection_context()
def get_documents_info(cls, id):
fields = [
Document.id,
Document.name,
Document.progress
]
return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
cls.model.id == id,
Document.progress > 0,
Document.progress < 1
).dicts()
@classmethod
@DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count

View File

@@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search

+CANVAS_DEBUG_DOC_ID = "dataflow_x"

 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
@@ -70,7 +71,7 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_task(cls, task_id):
+    def get_task(cls, task_id, doc_ids=[]):
         """Retrieve detailed task information by task ID.

         This method fetches comprehensive task details including associated document,
@@ -84,6 +85,10 @@ class TaskService(CommonService):
             dict: Task details dictionary containing all task information and related metadata.
                   Returns None if task is not found or has exceeded retry limit.
         """
+        doc_id = cls.model.doc_id
+        if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
+            doc_id = doc_ids[0]
         fields = [
             cls.model.id,
             cls.model.doc_id,
@@ -109,7 +114,7 @@ class TaskService(CommonService):
         ]
         docs = (
             cls.model.select(*fields)
-            .join(Document, on=(cls.model.doc_id == Document.id))
+            .join(Document, on=(doc_id == Document.id))
             .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(cls.model.id == task_id)
@@ -472,14 +477,14 @@ def has_canceled(task_id):
     return False

-def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
     task = dict(
         id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
-        task_type="dataflow",
+        task_type="dataflow" if not rerun else "dataflow_rerun",
         priority=priority,
     )

View File

@@ -1,4 +1,5 @@
 import base64
+import logging
 from functools import partial
 from io import BytesIO
@@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
 test_image = base64.b64decode(test_image_base64)

-async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
+async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
     import logging
     from io import BytesIO
     import trio
@@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
     if len(arr) != 2:
         return
     bkt, nm = image_id.split("-")
-    blob = storage_get_func(bucket=bkt, filename=nm)
-    if not blob:
-        return
-    return Image.open(BytesIO(blob))
+    try:
+        blob = storage_get_func(bucket=bkt, filename=nm)
+        if not blob:
+            return
+        return Image.open(BytesIO(blob))
+    except Exception as e:
+        logging.exception(e)
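A self-contained sketch of how the hardened id2image might be called. The in-memory store and get_blob adapter below are stand-ins, not RAGFlow's storage API; any callable accepting bucket/filename keyword arguments works, and the id follows the "<bucket>-<object>" convention that id2image splits on.

from functools import partial
from api.utils.base64_image import id2image, test_image  # test_image: the module's sample PNG bytes

# Tiny in-memory stand-in for the object store (real callers would wrap the deployment's backend).
store = {("imagetemps", "sample"): test_image}

def get_blob(bucket: str, filename: str):
    return store.get((bucket, filename))

img = id2image("imagetemps-sample", partial(get_blob))       # PIL.Image on success
missing = id2image("imagetemps-missing", partial(get_blob))  # None; storage errors are now logged, not raised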

View File

@@ -21,6 +21,7 @@ import networkx as nx
 import trio

 from api import settings
+from api.db.services.document_service import DocumentService
 from api.utils import get_uuid
 from api.utils.api_utils import timeout
 from graphrag.entity_resolution import EntityResolution
@@ -54,7 +55,7 @@ async def run_graphrag(
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
-    for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]):
+    for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
         chunks.append(d["content_with_weight"])

     with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -125,6 +126,212 @@ async def run_graphrag(
         return
async def run_graphrag_for_kb(
row: dict,
doc_ids: list[str],
language: str,
kb_parser_config: dict,
chat_model,
embedding_model,
callback,
*,
with_resolution: bool = True,
with_community: bool = True,
max_parallel_docs: int = 4,
) -> dict:
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
fields_for_chunks = ["content_with_weight", "doc_id"]
if not doc_ids:
logging.info(f"Fetching all docs for {kb_id}")
docs, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
doc_ids = [doc["id"] for doc in docs]
doc_ids = list(dict.fromkeys(doc_ids))
if not doc_ids:
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
def load_doc_chunks(doc_id: str) -> list[str]:
from rag.utils import num_tokens_from_string
chunks = []
current_chunk = ""
for d in settings.retrievaler.chunk_list(
doc_id,
tenant_id,
[kb_id],
fields=fields_for_chunks,
sort_by_position=True,
):
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content
else:
if current_chunk:
chunks.append(current_chunk)
current_chunk = content
if current_chunk:
chunks.append(current_chunk)
return chunks
all_doc_chunks: dict[str, list[str]] = {}
total_chunks = 0
for doc_id in doc_ids:
chunks = load_doc_chunks(doc_id)
all_doc_chunks[doc_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_docs)
subgraphs: dict[str, object] = {}
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
async with semaphore:
try:
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph(
kg_extractor,
tenant_id,
kb_id,
doc_id,
chunks,
language,
kb_parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
)
if sg:
subgraphs[doc_id] = sg
callback(msg=f"{msg} done")
else:
failed_docs.append((doc_id, "subgraph is empty"))
callback(msg=f"{msg} empty")
except Exception as e:
failed_docs.append((doc_id, repr(e)))
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
async with trio.open_nursery() as nursery:
for doc_id in doc_ids:
nursery.start_soon(build_one, doc_id)
ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time()
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
try:
union_nodes: set = set()
final_graph = None
for doc_id in ok_docs:
sg = subgraphs[doc_id]
union_nodes.update(set(sg.nodes()))
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
sg,
embedding_model,
callback,
)
if new_graph is not None:
final_graph = new_graph
if final_graph is None:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
else:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
finally:
kb_lock.release()
if not with_resolution and not with_community:
now = trio.current_time()
callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
try:
subgraph_nodes = set()
for sg in subgraphs.values():
subgraph_nodes.update(set(sg.nodes()))
if with_resolution:
await resolve_entities(
final_graph,
subgraph_nodes,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
if with_community:
await extract_community(
final_graph,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
finally:
kb_lock.release()
now = trio.current_time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
return {
"ok_docs": ok_docs,
"failed_docs": failed_docs, # [(doc_id, error), ...]
"total_docs": len(doc_ids),
"total_chunks": total_chunks,
"seconds": now - start,
}
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
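A hypothetical invocation of the new KB-level entry point from a worker task. It assumes a task row, the LLM/embedding bundles, the knowledge base record, and a progress callback are already set up the way the per-document run_graphrag is driven; those names are placeholders, while the parameters and the returned dict come from the function above.

result = await run_graphrag_for_kb(
    row=task,                           # needs "tenant_id" and "kb_id"
    doc_ids=task.get("doc_ids", []),    # empty list means "all documents in the KB"
    language=task["language"],          # assumed to be carried on the task row
    kb_parser_config=kb.parser_config,  # supplies parser_config["graphrag"]
    chat_model=chat_mdl,
    embedding_model=embd_mdl,
    callback=progress_callback,
    with_resolution=True,
    with_community=True,
    max_parallel_docs=4,
)
# result: {"ok_docs": [...], "failed_docs": [(doc_id, error), ...],
#          "total_docs": ..., "total_chunks": ..., "seconds": ...}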

View File

@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import os
 import time
 from functools import partial
 from typing import Any

 import trio

 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
@@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
         self.set_output("_created_time", time.perf_counter())
         for k, v in kwargs.items():
             self.set_output(k, v)
-        #try:
-        with trio.fail_after(self._param.timeout):
-            await self._invoke(**kwargs)
-        self.callback(1, "Done")
-        #except Exception as e:
-        #    if self.get_exception_default_value():
-        #        self.set_exception_default_value()
-        #    else:
-        #        self.set_output("_ERROR", str(e))
-        #    logging.exception(e)
-        #    self.callback(-1, str(e))
+        try:
+            with trio.fail_after(self._param.timeout):
+                await self._invoke(**kwargs)
+            self.callback(1, "Done")
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+            self.callback(-1, str(e))
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()

View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,59 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from agent.component.llm import LLMParam, LLM
class ExtractorParam(LLMParam):
    def __init__(self):
        super().__init__()
        self.field_name = ""

    def check(self):
        super().check()
        self.check_empty(self.field_name, "Result Destination")


class Extractor(LLM):
    component_name = "Extractor"

    async def _invoke(self, **kwargs):
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                chunks = args[k]
                chunks_key = k

        if chunks:
            prog = 0
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                prog += 1./len(chunks)
                self.callback(prog, f"{i+1} / {len(chunks)}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExtractorFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: list[str] | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@@ -17,15 +17,11 @@ import datetime
 import json
 import logging
 import random
-import time
 from timeit import default_timer as timer

 import trio

 from agent.canvas import Graph
-from api.db import PipelineTaskType
 from api.db.services.document_service import DocumentService
-from api.db.services.task_service import has_canceled
+from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
-from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from rag.utils.redis_conn import REDIS_CONN
@@ -34,8 +30,8 @@ class Pipeline(Graph):
         if isinstance(dsl, dict):
             dsl = json.dumps(dsl, ensure_ascii=False)
         super().__init__(dsl, tenant_id, task_id)
-        if self._doc_id == "x":
-            self._doc_id = None
+        if doc_id == CANVAS_DEBUG_DOC_ID:
+            doc_id = None
         self._doc_id = doc_id
         self._flow_id = flow_id
         self._kb_id = None
@@ -80,7 +76,7 @@ class Pipeline(Graph):
                 }
             ]
             REDIS_CONN.set_obj(log_key, obj, 60 * 30)
-            if self._doc_id:
+            if self._doc_id and self.task_id:
                 percentage = 1.0 / len(self.components.items())
                 msg = ""
                 finished = 0.0
@@ -96,7 +92,7 @@ class Pipeline(Graph):
                     if finished < 0:
                         break
                     finished += o["trace"][-1]["progress"] * percentage
-                DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
+                TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
         except Exception as e:
             logging.exception(e)
@@ -113,34 +109,32 @@ class Pipeline(Graph):
             logging.exception(e)
             return []

-    def reset(self):
-        super().reset()
+    async def run(self, **kwargs):
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
             REDIS_CONN.set_obj(log_key, [], 60 * 10)
         except Exception as e:
             logging.exception(e)
-        self.error = ""

-    async def run(self, **kwargs):
-        st = time.perf_counter()
         if not self.path:
             self.path.append("File")

-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-            )
-        self.error = ""
-        idx = len(self.path) - 1
-        if idx == 0:
-            cpn_obj = self.get_component_obj(self.path[0])
-            await cpn_obj.invoke(**kwargs)
-            if cpn_obj.error():
-                self.error = "[ERROR]" + cpn_obj.error()
-        else:
-            idx += 1
-        self.path.extend(cpn_obj.get_downstream())
+        cpn_obj = self.get_component_obj(self.path[0])
+        await cpn_obj.invoke(**kwargs)
+        if cpn_obj.error():
+            self.error = "[ERROR]" + cpn_obj.error()
+            self.callback(cpn_obj.component_name, -1, self.error)
+
+        if self._doc_id:
+            TaskService.update_progress(self.task_id, {
+                "progress": random.randint(0, 5) / 100.0,
+                "progress_msg": "Start the pipeline...",
+                "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
+
+        idx = len(self.path) - 1
+        cpn_obj = self.get_component_obj(self.path[idx])
+        idx += 1
+        self.path.extend(cpn_obj.get_downstream())

         while idx < len(self.path) and not self.error:
             last_cpn = self.get_component_obj(self.path[idx - 1])
@@ -152,23 +146,21 @@ class Pipeline(Graph):
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(invoke)
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
-                self.callback(cpn_obj.component_name, -1, self.error)
+                self.callback(cpn_obj._id, -1, self.error)
                 break
             idx += 1
             self.path.extend(cpn_obj.get_downstream())

-        self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
-        if self._doc_id:
-            DocumentService.update_by_id(
-                self._doc_id,
-                {
-                    "progress": 1 if not self.error else -1,
-                    "progress_msg": "Pipeline finished...\n" + self.error,
-                    "process_duration": time.perf_counter() - st,
-                },
-            )
-            PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
+        self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
+        if not self.error:
+            return self.get_component_obj(self.path[-1]).output()
+
+        TaskService.update_progress(self.task_id, {
+            "progress": -1,
+            "progress_msg": f"[ERROR]: {self.error}"})
+        return {}

View File

@@ -99,7 +99,7 @@ class Splitter(ProcessBase):
                 {
                     "text": RAGFlowPdfParser.remove_tag(c),
                     "image": img,
-                    "positions": RAGFlowPdfParser.extract_positions(c),
+                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
                 }
                 for c, img in zip(chunks, images)
             ]
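A standalone illustration of the positions change. The exact tuple layout produced by RAGFlowPdfParser.extract_positions is an assumption here (a list of zero-based page indexes followed by coordinates); the point is that the last page index is shifted to a one-based page number while the coordinates pass through unchanged.

raw_positions = [([0], 56.0, 320.0, 110.0, 140.0), ([0, 1], 56.0, 320.0, 700.0, 730.0)]
converted = [[pos[0][-1] + 1, *pos[1:]] for pos in raw_positions]
print(converted)  # [[1, 56.0, 320.0, 110.0, 140.0], [2, 56.0, 320.0, 700.0, 730.0]]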

View File

@@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
                 ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
             if ck.get("keywords"):
                 ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
-            ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
-            ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            if ck.get("summary"):
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
+            else:
+                ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
+                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             if i % 100 == 99:
                 self.callback(i * 1.0 / len(chunks) / parts)

View File

@@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
         res.append(d)
     return res

+
 def tokenize_chunks_with_images(chunks, doc, eng, images):
     res = []
     # wrap up as es documents
@@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
         res.append(d)
     return res

+
 def tokenize_table(tbls, doc, eng, batch_size=10):
     res = []
     # add tables

View File

@@ -383,7 +383,7 @@ class Dealer:
         vector_column = f"q_{dim}_vec"
         zero_vector = [0.0] * dim
         sim_np = np.array(sim)
         filtered_count = (sim_np >= similarity_threshold).sum()
         ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error
         for i in idx:
             if sim[i] < similarity_threshold:
@@ -444,12 +444,27 @@ class Dealer:
     def chunk_list(self, doc_id: str, tenant_id: str,
                    kb_ids: list[str], max_count=1024,
                    offset=0,
-                   fields=["docnm_kwd", "content_with_weight", "img_id"]):
+                   fields=["docnm_kwd", "content_with_weight", "img_id"],
+                   sort_by_position: bool = False):
         condition = {"doc_id": doc_id}
+        fields_set = set(fields or [])
+        if sort_by_position:
+            for need in ("page_num_int", "position_int", "top_int"):
+                if need not in fields_set:
+                    fields_set.add(need)
+            fields = list(fields_set)
+
+        orderBy = OrderByExpr()
+        if sort_by_position:
+            orderBy.asc("page_num_int")
+            orderBy.asc("position_int")
+            orderBy.asc("top_int")
+
         res = []
         bs = 128
         for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
+            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
                                            kb_ids)
             dict_chunks = self.dataStore.getFields(es_res, fields)
             for id, doc in dict_chunks.items():
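A sketch of the new flag in use, mirroring the call made by run_graphrag earlier in this diff; doc_id, tenant_id, and kb_id are placeholders. With sort_by_position=True the retriever adds page_num_int/position_int/top_int to the requested fields and returns chunks in reading order instead of store order.

for d in settings.retrievaler.chunk_list(
    doc_id,                 # placeholder document id
    tenant_id,              # placeholder tenant id
    [kb_id],                # placeholder knowledge base id
    fields=["content_with_weight", "doc_id"],
    sort_by_position=True,
):
    print(d["content_with_weight"][:80])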

View File

@@ -21,11 +21,12 @@ import sys
 import threading
 import time

 from api.db.services.canvas_service import UserCanvasService
+from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
 from api.utils.log_utils import init_root_logger, get_project_base_directory
-from graphrag.general.index import run_graphrag
+from graphrag.general.index import run_graphrag_for_kb
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -49,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -85,6 +86,12 @@ FACTORY = {
     ParserType.TAG.value: tag
 }

+TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
+    "dataflow" : PipelineTaskType.PARSE,
+    "raptor": PipelineTaskType.RAPTOR,
+    "graphrag": PipelineTaskType.GRAPH_RAG,
+}
+
 UNACKED_ITERATOR = None

 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@@ -140,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
     max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
     logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")

+
 # SIGUSR2 handler: stop tracemalloc
 def stop_tracemalloc(signum, frame):
     if tracemalloc.is_tracing():
@@ -148,6 +156,7 @@ def stop_tracemalloc(signum, frame):
     else:
         logging.info("tracemalloc not running")

+
 class TaskCanceledException(Exception):
     def __init__(self, msg):
         self.msg = msg
@@ -215,6 +224,10 @@ async def collect():
     canceled = False
     if msg.get("doc_id", "") == "x":
         task = msg
+        if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
+            print(f"hack {msg['doc_ids']=}=",flush=True)
+            task = TaskService.get_task(msg["id"], msg["doc_ids"])
+            task["doc_ids"] = msg["doc_ids"]
     else:
         task = TaskService.get_task(msg["id"])
@@ -461,11 +474,97 @@

 async def run_dataflow(task: dict):
+    task_start_ts = timer()
     dataflow_id = task["dataflow_id"]
-    e, cvs = UserCanvasService.get_by_id(dataflow_id)
-    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
-    pipeline.reset()
-    await pipeline.run(file=task.get("file"))
+    doc_id = task["doc_id"]
+    task_id = task["id"]
+    if task["task_type"] == "dataflow":
+        e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found."
dsl = cvs.dsl
else:
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
assert e, "Pipeline log not found."
dsl = pipeline_log.dsl
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID:
return
if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
if chunks.get("chunks"):
chunks = chunks["chunks"]
elif chunks.get("json"):
chunks = chunks["json"]
elif chunks.get("markdown"):
chunks = [{"text": [chunks["markdown"]]}]
elif chunks.get("text"):
chunks = [{"text": [chunks["text"]]}]
elif chunks.get("html"):
chunks = [{"text": [chunks["html"]]}]
keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
set_progress(task_id, prog=0.82, msg="Start to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
set_progress(task_id, prog=prog, msg=f"{i // EMBEDDING_BATCH_SIZE + 1} / {max(1, len(texts) // EMBEDDING_BATCH_SIZE)}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
for ck in chunks:
ck["doc_id"] = task["doc_id"]
ck["kb_id"] = [str(task["kb_id"])]
ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
del ck["questions"]
if "keywords" in ck:
del ck["keywords"]
if "summary" in ck:
del ck["summary"]
del ck["text"]
start_ts = timer()
set_progress(task_id, prog=0.82, msg="Start to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
@timeout(3600)
@ -510,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
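# delete_image and insert_es are factored out of do_handle_task so the classic parsing path
# and the new dataflow path share the same indexing and image-cleanup logic.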
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise
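# insert_es bulk-inserts chunks in DOC_BULK_SIZE batches, aborts on cancellation or doc-store
# errors, and rolls back already-inserted chunks (doc store entries and images) when the task
# record has disappeared. Returns True on success, None otherwise.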
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return True
@timeout(60*60*2, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == "x": if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return
@ -559,7 +695,7 @@ async def do_handle_task(task):
init_kb(task, vector_size)
if task_type == "dataflow": if task_type[:len("dataflow")] == "dataflow":
await run_dataflow(task)
return
@ -580,7 +716,19 @@ async def do_handle_task(task):
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
async with kg_limiter:
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
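# Replaces the per-document run_graphrag call: run_graphrag_for_kb builds the graph across
# the task's doc_ids using the knowledge base parser config.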
result = await run_graphrag_for_kb(
row=task,
doc_ids=task.get("doc_ids", []),
language=task_language,
kb_parser_config=task_parser_config,
chat_model=chat_model,
embedding_model=embedding_model,
callback=progress_callback,
with_resolution=with_resolution,
with_community=with_community,
)
logging.info(f"GraphRAG task result for task {task}:\n{result}")
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts)) progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return return
else: else:
@ -609,48 +757,15 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = "" e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
if not e:
async def delete_image(kb_id, chunk_id): return
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
raise
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks), task_to_page, len(chunks),
timer() - start_ts)) timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
@ -667,6 +782,10 @@ async def handle_task():
if not task:
await trio.sleep(5)
return
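# Resolve the pipeline-log task type up front so the finally block below can always record the outcome.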
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@ -686,7 +805,13 @@ async def handle_task():
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE) finally:
task_document_ids = []
if task_type in ["graphrag"]:
task_document_ids = task["doc_ids"]
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack()