mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-03 17:15:08 +08:00
Compare commits
2 Commits
a6039cf563
...
1b19d302c5
| Author | SHA1 | Date | |
|---|---|---|---|
| 1b19d302c5 | |||
| 840b2b5809 |
@ -113,6 +113,15 @@ class LLM(ComponentBase):
|
||||
def add2system_prompt(self, txt):
|
||||
self._param.sys_prompt += txt
|
||||
|
||||
def _sys_prompt_and_msg(self, msg, args):
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
p = deepcopy(p)
|
||||
p["content"] = self.string_format(p["content"], args)
|
||||
msg.append(p)
|
||||
return msg, self.string_format(self._param.sys_prompt, args)
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
if self._param.visual_files_var:
|
||||
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
|
||||
@ -128,7 +137,6 @@ class LLM(ComponentBase):
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
sys_prompt = self._param.sys_prompt
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
@ -138,16 +146,8 @@ class LLM(ComponentBase):
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
msg.append(deepcopy(p))
|
||||
|
||||
sys_prompt = self.string_format(sys_prompt, args)
|
||||
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
for m in msg:
|
||||
m["content"] = self.string_format(m["content"], args)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
sys_prompt += citation_prompt(user_defined_prompt)
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
@ -174,6 +175,25 @@ def run():
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
|
||||
@ -14,17 +14,19 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils import get_uuid
|
||||
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -435,18 +437,60 @@ def list_pipeline_logs():
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in docs:
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
docs = filtered_docs
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -478,3 +522,68 @@ def pipeline_log_detail():
|
||||
return get_data_error_result(message="Invalid pipeline log ID")
|
||||
|
||||
return get_json_result(data=log.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
if not doc_ids:
|
||||
return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
document_ids = set()
|
||||
sample_document = {}
|
||||
for doc_id in doc_ids:
|
||||
ok, document = DocumentService.get_by_id(doc_id)
|
||||
if ok:
|
||||
document_ids.add(document.id)
|
||||
if not sample_document:
|
||||
sample_document = document.to_dict()
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_graphrag():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="GraphRAG Task ID Not Found")
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "DOWNLOAD"
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -1065,11 +1068,15 @@ def migrate_db():
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -121,12 +121,20 @@ class DocumentService(CommonService):
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
@ -507,6 +515,9 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
"""
|
||||
highly rely on the strict deduplication guarantee from Document
|
||||
"""
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
@ -656,6 +667,7 @@ class DocumentService(CommonService):
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
@ -741,7 +753,11 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
@ -751,7 +767,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
@ -764,7 +780,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty == "graphrag":
|
||||
task["doc_ids"] = doc_ids
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
|
||||
@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.create_time,
|
||||
|
||||
@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES
|
||||
from api.db.db_models import DB, PipelineOperationLog
|
||||
from api.db.db_models import DB, PipelineOperationLog, Document
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -31,7 +32,7 @@ class PipelineOperationLogService(CommonService):
|
||||
model = PipelineOperationLog
|
||||
|
||||
@classmethod
|
||||
def get_cls_model_fields(cls):
|
||||
def get_file_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.document_id,
|
||||
@ -59,24 +60,47 @@ class PipelineOperationLogService(CommonService):
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_dataset_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type):
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
tenant_id = ""
|
||||
title = ""
|
||||
avatar = ""
|
||||
dsl = ""
|
||||
operation_status = ""
|
||||
referred_document_id = document_id
|
||||
|
||||
ok, document = DocumentService.get_by_id(document_id)
|
||||
if referred_document_id == "x" and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document {document_id} not found")
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(document_id)
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document {document_id} not found")
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
|
||||
if pipeline_id:
|
||||
@ -84,7 +108,7 @@ class PipelineOperationLogService(CommonService):
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
|
||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
|
||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
|
||||
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
@ -93,7 +117,7 @@ class PipelineOperationLogService(CommonService):
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||
|
||||
tenant_id = kb_info.tenant_id
|
||||
title = document.name
|
||||
@ -104,7 +128,7 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id,
|
||||
document_id=document_id, # "x" or real document_id
|
||||
tenant_id=tenant_id,
|
||||
kb_id=document.kb_id,
|
||||
pipeline_id=pipeline_id,
|
||||
@ -132,18 +156,20 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
else:
|
||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
|
||||
logs = logs.where(cls.model.document_id != "x")
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if types:
|
||||
@ -161,3 +187,38 @@ class PipelineOperationLogService(CommonService):
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [
|
||||
Document.id,
|
||||
Document.name,
|
||||
Document.progress
|
||||
]
|
||||
return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
|
||||
cls.model.id == id,
|
||||
Document.progress > 0,
|
||||
Document.progress < 1
|
||||
).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
|
||||
def trim_header_by_lines(text: str, max_length) -> str:
|
||||
# Trim header text to maximum length while preserving line breaks
|
||||
@ -70,7 +71,7 @@ class TaskService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_task(cls, task_id):
|
||||
def get_task(cls, task_id, doc_ids=[]):
|
||||
"""Retrieve detailed task information by task ID.
|
||||
|
||||
This method fetches comprehensive task details including associated document,
|
||||
@ -84,6 +85,10 @@ class TaskService(CommonService):
|
||||
dict: Task details dictionary containing all task information and related metadata.
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
@ -109,7 +114,7 @@ class TaskService(CommonService):
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Document, on=(doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
@ -472,14 +477,14 @@ def has_canceled(task_id):
|
||||
return False
|
||||
|
||||
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
|
||||
|
||||
task = dict(
|
||||
id=task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow",
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import trio
|
||||
@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
|
||||
if len(arr) != 2:
|
||||
return
|
||||
bkt, nm = image_id.split("-")
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
try:
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -21,6 +21,7 @@ import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.entity_resolution import EntityResolution
|
||||
@ -54,7 +55,7 @@ async def run_graphrag(
|
||||
start = trio.current_time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]):
|
||||
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
@ -125,6 +126,212 @@ async def run_graphrag(
|
||||
return
|
||||
|
||||
|
||||
async def run_graphrag_for_kb(
|
||||
row: dict,
|
||||
doc_ids: list[str],
|
||||
language: str,
|
||||
kb_parser_config: dict,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
*,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True,
|
||||
max_parallel_docs: int = 4,
|
||||
) -> dict:
|
||||
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = trio.current_time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
|
||||
if not doc_ids:
|
||||
logging.info(f"Fetching all docs for {kb_id}")
|
||||
docs, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
doc_ids = [doc["id"] for doc in docs]
|
||||
|
||||
doc_ids = list(dict.fromkeys(doc_ids))
|
||||
if not doc_ids:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
|
||||
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
def load_doc_chunks(doc_id: str) -> list[str]:
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for d in settings.retrievaler.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
):
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = content
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
all_doc_chunks: dict[str, list[str]] = {}
|
||||
total_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
all_doc_chunks[doc_id] = chunks
|
||||
total_chunks += len(chunks)
|
||||
|
||||
if total_chunks == 0:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = trio.Semaphore(max_parallel_docs)
|
||||
|
||||
subgraphs: dict[str, object] = {}
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
with trio.fail_after(deadline):
|
||||
sg = await generate_subgraph(
|
||||
kg_extractor,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
callback(msg=f"{msg} done")
|
||||
else:
|
||||
failed_docs.append((doc_id, "subgraph is empty"))
|
||||
callback(msg=f"{msg} empty")
|
||||
except Exception as e:
|
||||
failed_docs.append((doc_id, repr(e)))
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for doc_id in doc_ids:
|
||||
nursery.start_soon(build_one, doc_id)
|
||||
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
now = trio.current_time()
|
||||
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||
|
||||
try:
|
||||
union_nodes: set = set()
|
||||
final_graph = None
|
||||
|
||||
for doc_id in ok_docs:
|
||||
sg = subgraphs[doc_id]
|
||||
union_nodes.update(set(sg.nodes()))
|
||||
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
sg,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if new_graph is not None:
|
||||
final_graph = new_graph
|
||||
|
||||
if final_graph is None:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
|
||||
else:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
now = trio.current_time()
|
||||
callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||
|
||||
try:
|
||||
subgraph_nodes = set()
|
||||
for sg in subgraphs.values():
|
||||
subgraph_nodes.update(set(sg.nodes()))
|
||||
|
||||
if with_resolution:
|
||||
await resolve_entities(
|
||||
final_graph,
|
||||
subgraph_nodes,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
|
||||
if with_community:
|
||||
await extract_community(
|
||||
final_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
now = trio.current_time()
|
||||
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||
return {
|
||||
"ok_docs": ok_docs,
|
||||
"failed_docs": failed_docs, # [(doc_id, error), ...]
|
||||
"total_docs": len(doc_ids),
|
||||
"total_chunks": total_chunks,
|
||||
"seconds": now - start,
|
||||
}
|
||||
|
||||
|
||||
async def generate_subgraph(
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
|
||||
@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
for k, v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
#try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
#except Exception as e:
|
||||
# if self.get_exception_default_value():
|
||||
# self.set_exception_default_value()
|
||||
# else:
|
||||
# self.set_output("_ERROR", str(e))
|
||||
# logging.exception(e)
|
||||
# self.callback(-1, str(e))
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self.callback(-1, str(e))
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
|
||||
15
rag/flow/extractor/__init__.py
Normal file
15
rag/flow/extractor/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
59
rag/flow/extractor/extractor.py
Normal file
59
rag/flow/extractor/extractor.py
Normal file
@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
class ExtractorParam(LLMParam):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field_name = ""
|
||||
|
||||
def check(self):
|
||||
super().check()
|
||||
self.check_empty(self.field_name, "Result Destination")
|
||||
|
||||
|
||||
class Extractor(LLM):
|
||||
component_name = "Extractor"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||
inputs = self.get_input_elements()
|
||||
chunks = []
|
||||
chunks_key = ""
|
||||
args = {}
|
||||
for k, v in inputs.items():
|
||||
args[k] = v["value"]
|
||||
if isinstance(args[k], list):
|
||||
chunks = args[k]
|
||||
chunks_key = k
|
||||
|
||||
if chunks:
|
||||
prog = 0
|
||||
for i, ck in enumerate(chunks):
|
||||
args[chunks_key] = ck["text"]
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
ck[self._param.field_name] = self._generate(msg)
|
||||
prog += 1./len(chunks)
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
self.set_output("chunks", chunks)
|
||||
else:
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
|
||||
|
||||
|
||||
38
rag/flow/extractor/schema.py
Normal file
38
rag/flow/extractor/schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExtractorFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: list[str] | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
@ -17,15 +17,11 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
from agent.canvas import Graph
|
||||
from api.db import PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@ -34,8 +30,8 @@ class Pipeline(Graph):
|
||||
if isinstance(dsl, dict):
|
||||
dsl = json.dumps(dsl, ensure_ascii=False)
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
if self._doc_id == "x":
|
||||
self._doc_id = None
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
doc_id = None
|
||||
self._doc_id = doc_id
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
@ -80,7 +76,7 @@ class Pipeline(Graph):
|
||||
}
|
||||
]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
|
||||
if self._doc_id:
|
||||
if self._doc_id and self.task_id:
|
||||
percentage = 1.0 / len(self.components.items())
|
||||
msg = ""
|
||||
finished = 0.0
|
||||
@ -96,7 +92,7 @@ class Pipeline(Graph):
|
||||
if finished < 0:
|
||||
break
|
||||
finished += o["trace"][-1]["progress"] * percentage
|
||||
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
|
||||
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -113,34 +109,32 @@ class Pipeline(Graph):
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
|
||||
async def run(self, **kwargs):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60 * 10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.error = ""
|
||||
if not self.path:
|
||||
self.path.append("File")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
)
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
|
||||
if self._doc_id:
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": random.randint(0, 5) / 100.0,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
|
||||
|
||||
idx = len(self.path) - 1
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx - 1])
|
||||
@ -152,23 +146,21 @@ class Pipeline(Graph):
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
self.callback(cpn_obj._id, -1, self.error)
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id,
|
||||
{
|
||||
"progress": 1 if not self.error else -1,
|
||||
"progress_msg": "Pipeline finished...\n" + self.error,
|
||||
"process_duration": time.perf_counter() - st,
|
||||
},
|
||||
)
|
||||
if not self.error:
|
||||
return self.get_component_obj(self.path[-1]).output()
|
||||
|
||||
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": -1,
|
||||
"progress_msg": f"[ERROR]: {self.error}"})
|
||||
|
||||
return {}
|
||||
|
||||
@ -99,7 +99,7 @@ class Splitter(ProcessBase):
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
}
|
||||
for c, img in zip(chunks, images)
|
||||
]
|
||||
|
||||
@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i * 1.0 / len(chunks) / parts)
|
||||
|
||||
|
||||
@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
res = []
|
||||
# add tables
|
||||
|
||||
@ -383,7 +383,7 @@ class Dealer:
|
||||
vector_column = f"q_{dim}_vec"
|
||||
zero_vector = [0.0] * dim
|
||||
sim_np = np.array(sim)
|
||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||
for i in idx:
|
||||
if sim[i] < similarity_threshold:
|
||||
@ -444,12 +444,27 @@ class Dealer:
|
||||
def chunk_list(self, doc_id: str, tenant_id: str,
|
||||
kb_ids: list[str], max_count=1024,
|
||||
offset=0,
|
||||
fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||
fields=["docnm_kwd", "content_with_weight", "img_id"],
|
||||
sort_by_position: bool = False):
|
||||
condition = {"doc_id": doc_id}
|
||||
|
||||
fields_set = set(fields or [])
|
||||
if sort_by_position:
|
||||
for need in ("page_num_int", "position_int", "top_int"):
|
||||
if need not in fields_set:
|
||||
fields_set.add(need)
|
||||
fields = list(fields_set)
|
||||
|
||||
orderBy = OrderByExpr()
|
||||
if sort_by_position:
|
||||
orderBy.asc("page_num_int")
|
||||
orderBy.asc("position_int")
|
||||
orderBy.asc("top_int")
|
||||
|
||||
res = []
|
||||
bs = 128
|
||||
for p in range(offset, max_count, bs):
|
||||
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
||||
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
|
||||
kb_ids)
|
||||
dict_chunks = self.dataStore.getFields(es_res, fields)
|
||||
for id, doc in dict_chunks.items():
|
||||
|
||||
@ -21,11 +21,12 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.utils.api_utils import timeout
|
||||
from api.utils.base64_image import image2id
|
||||
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
||||
from graphrag.general.index import run_graphrag
|
||||
from graphrag.general.index import run_graphrag_for_kb
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.prompts import keyword_extraction, question_proposal, content_tagging
|
||||
@ -49,7 +50,7 @@ from peewee import DoesNotExist
|
||||
from api.db import LLMType, ParserType, PipelineTaskType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService, has_canceled
|
||||
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api import settings
|
||||
from api.versions import get_ragflow_version
|
||||
@ -85,6 +86,12 @@ FACTORY = {
|
||||
ParserType.TAG.value: tag
|
||||
}
|
||||
|
||||
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||
"dataflow" : PipelineTaskType.PARSE,
|
||||
"raptor": PipelineTaskType.RAPTOR,
|
||||
"graphrag": PipelineTaskType.GRAPH_RAG,
|
||||
}
|
||||
|
||||
UNACKED_ITERATOR = None
|
||||
|
||||
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
@ -140,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
if tracemalloc.is_tracing():
|
||||
@ -148,6 +156,7 @@ def stop_tracemalloc(signum, frame):
|
||||
else:
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@ -215,6 +224,10 @@ async def collect():
|
||||
canceled = False
|
||||
if msg.get("doc_id", "") == "x":
|
||||
task = msg
|
||||
if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
|
||||
print(f"hack {msg['doc_ids']=}=",flush=True)
|
||||
task = TaskService.get_task(msg["id"], msg["doc_ids"])
|
||||
task["doc_ids"] = msg["doc_ids"]
|
||||
else:
|
||||
task = TaskService.get_task(msg["id"])
|
||||
|
||||
@ -461,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
|
||||
async def run_dataflow(task: dict):
|
||||
task_start_ts = timer()
|
||||
dataflow_id = task["dataflow_id"]
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
|
||||
pipeline.reset()
|
||||
await pipeline.run(file=task.get("file"))
|
||||
doc_id = task["doc_id"]
|
||||
task_id = task["id"]
|
||||
if task["task_type"] == "dataflow":
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
assert e, "User pipeline not found."
|
||||
dsl = cvs.dsl
|
||||
else:
|
||||
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
|
||||
assert e, "Pipeline log not found."
|
||||
dsl = pipeline_log.dsl
|
||||
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
|
||||
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
return
|
||||
|
||||
if not chunks:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
|
||||
if chunks.get("chunks"):
|
||||
chunks = chunks["chunks"]
|
||||
elif chunks.get("json"):
|
||||
chunks = chunks["json"]
|
||||
elif chunks.get("markdown"):
|
||||
chunks = [{"text": [chunks["markdown"]]}]
|
||||
elif chunks.get("text"):
|
||||
chunks = [{"text": [chunks["text"]]}]
|
||||
elif chunks.get("html"):
|
||||
chunks = [{"text": [chunks["html"]]}]
|
||||
|
||||
keys = [k for o in chunks for k in list(o.keys())]
|
||||
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
|
||||
set_progress(task_id, prog=0.82, msg="Start to embedding...")
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
|
||||
vects = np.array([])
|
||||
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
|
||||
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
vects = np.concatenate((vects, vts), axis=0)
|
||||
embedding_token_consumption += c
|
||||
prog += delta
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
v = vects[i].tolist()
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
|
||||
for ck in chunks:
|
||||
ck["doc_id"] = task["doc_id"]
|
||||
ck["kb_id"] = [str(task["kb_id"])]
|
||||
ck["docnm_kwd"] = task["name"]
|
||||
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||
if "questions" in ck:
|
||||
del ck["questions"]
|
||||
if "keywords" in ck:
|
||||
del ck["keywords"]
|
||||
if "summary" in ck:
|
||||
del ck["summary"]
|
||||
del ck["text"]
|
||||
|
||||
start_ts = timer()
|
||||
set_progress(task_id, prog=0.82, msg="Start to index...")
|
||||
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
||||
if not e:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
return
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
task_time_cost = timer() - task_start_ts
|
||||
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
logging.info(
|
||||
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
@ -510,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
return res, tk_count
|
||||
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
|
||||
raise
|
||||
|
||||
|
||||
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
|
||||
return
|
||||
return True
|
||||
|
||||
|
||||
@timeout(60*60*2, 1)
|
||||
async def do_handle_task(task):
|
||||
task_type = task.get("task_type", "")
|
||||
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == "x":
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -559,7 +695,7 @@ async def do_handle_task(task):
|
||||
|
||||
init_kb(task, vector_size)
|
||||
|
||||
if task_type == "dataflow":
|
||||
if task_type[:len("dataflow")] == "dataflow":
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
@ -580,7 +716,19 @@ async def do_handle_task(task):
|
||||
with_resolution = graphrag_conf.get("resolution", False)
|
||||
with_community = graphrag_conf.get("community", False)
|
||||
async with kg_limiter:
|
||||
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
||||
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
||||
result = await run_graphrag_for_kb(
|
||||
row=task,
|
||||
doc_ids=task.get("doc_ids", []),
|
||||
language=task_language,
|
||||
kb_parser_config=task_parser_config,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
callback=progress_callback,
|
||||
with_resolution=with_resolution,
|
||||
with_community=with_community,
|
||||
)
|
||||
logging.info(f"GraphRAG task result for task {task}:\n{result}")
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
else:
|
||||
@ -609,48 +757,15 @@ async def do_handle_task(task):
|
||||
|
||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||
start_ts = timer()
|
||||
doc_store_result = ""
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
async with minio_limiter:
|
||||
STORAGE_IMPL.delete(kb_id, chunk_id)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
|
||||
raise
|
||||
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
|
||||
return
|
||||
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
|
||||
if not e:
|
||||
return
|
||||
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
timer() - start_ts))
|
||||
|
||||
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
task_time_cost = timer() - task_start_ts
|
||||
@ -667,6 +782,10 @@ async def handle_task():
|
||||
if not task:
|
||||
await trio.sleep(5)
|
||||
return
|
||||
|
||||
task_type = task["task_type"]
|
||||
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
|
||||
|
||||
try:
|
||||
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
||||
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
||||
@ -686,7 +805,13 @@ async def handle_task():
|
||||
except Exception:
|
||||
pass
|
||||
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE)
|
||||
finally:
|
||||
task_document_ids = []
|
||||
if task_type in ["graphrag"]:
|
||||
task_document_ids = task["doc_ids"]
|
||||
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
|
||||
redis_msg.ack()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user