mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)
### What problem does this PR solve? Add foundational support for GraphRAG dataset pipeline logs ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -14,17 +14,19 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||||
|
from api.db.services.task_service import TaskService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
|
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
@ -435,18 +437,60 @@ def list_pipeline_logs():
|
|||||||
suffix = req.get("suffix", [])
|
suffix = req.get("suffix", [])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||||
|
|
||||||
if create_time_from or create_time_to:
|
if create_time_from or create_time_to:
|
||||||
filtered_docs = []
|
filtered_docs = []
|
||||||
for doc in docs:
|
for doc in logs:
|
||||||
doc_create_time = doc.get("create_time", 0)
|
doc_create_time = doc.get("create_time", 0)
|
||||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
docs = filtered_docs
|
logs = filtered_docs
|
||||||
|
|
||||||
|
|
||||||
return get_json_result(data={"total": tol, "docs": docs})
|
return get_json_result(data={"total": tol, "logs": logs})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def list_pipeline_dataset_logs():
|
||||||
|
kb_id = request.args.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
page_number = int(request.args.get("page", 0))
|
||||||
|
items_per_page = int(request.args.get("page_size", 0))
|
||||||
|
orderby = request.args.get("orderby", "create_time")
|
||||||
|
if request.args.get("desc", "true").lower() == "false":
|
||||||
|
desc = False
|
||||||
|
else:
|
||||||
|
desc = True
|
||||||
|
create_time_from = int(request.args.get("create_time_from", 0))
|
||||||
|
create_time_to = int(request.args.get("create_time_to", 0))
|
||||||
|
|
||||||
|
req = request.get_json()
|
||||||
|
|
||||||
|
operation_status = req.get("operation_status", [])
|
||||||
|
if operation_status:
|
||||||
|
invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
|
||||||
|
if invalid_status:
|
||||||
|
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
|
||||||
|
|
||||||
|
if create_time_from or create_time_to:
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in logs:
|
||||||
|
doc_create_time = doc.get("create_time", 0)
|
||||||
|
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
logs = filtered_docs
|
||||||
|
|
||||||
|
|
||||||
|
return get_json_result(data={"total": tol, "logs": logs})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -478,3 +522,68 @@ def pipeline_log_detail():
|
|||||||
return get_data_error_result(message="Invalid pipeline log ID")
|
return get_data_error_result(message="Invalid pipeline log ID")
|
||||||
|
|
||||||
return get_json_result(data=log.to_dict())
|
return get_json_result(data=log.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def run_graphrag():
|
||||||
|
req = request.json
|
||||||
|
|
||||||
|
kb_id = req.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
doc_ids = req.get("doc_ids", [])
|
||||||
|
if not doc_ids:
|
||||||
|
return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.graphrag_task_id
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||||
|
|
||||||
|
if task and task.progress not in [-1, 1]:
|
||||||
|
return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
|
||||||
|
|
||||||
|
document_ids = set()
|
||||||
|
sample_document = {}
|
||||||
|
for doc_id in doc_ids:
|
||||||
|
ok, document = DocumentService.get_by_id(doc_id)
|
||||||
|
if ok:
|
||||||
|
document_ids.add(document.id)
|
||||||
|
if not sample_document:
|
||||||
|
sample_document = document.to_dict()
|
||||||
|
|
||||||
|
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
|
||||||
|
|
||||||
|
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||||
|
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||||
|
|
||||||
|
return get_json_result(data={"graphrag_task_id": task_id})
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def trace_graphrag():
|
||||||
|
kb_id = request.args.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
|
||||||
|
task_id = kb.graphrag_task_id
|
||||||
|
if not task_id:
|
||||||
|
return get_error_data_result(message="GraphRAG Task ID Not Found")
|
||||||
|
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
return get_json_result(data=task.to_dict())
|
||||||
|
|||||||
@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
|||||||
|
|
||||||
class PipelineTaskType(StrEnum):
|
class PipelineTaskType(StrEnum):
|
||||||
PARSE = "Parse"
|
PARSE = "Parse"
|
||||||
DOWNLOAD = "DOWNLOAD"
|
DOWNLOAD = "Download"
|
||||||
|
RAPTOR = "RAPTOR"
|
||||||
|
GRAPH_RAG = "GraphRAG"
|
||||||
|
|
||||||
|
|
||||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
|
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
|
||||||
|
|
||||||
|
|
||||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||||
|
|||||||
@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
|
|||||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||||
pagerank = IntegerField(default=0, index=False)
|
pagerank = IntegerField(default=0, index=False)
|
||||||
|
|
||||||
|
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||||
|
|
||||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@ -1065,11 +1068,15 @@ def migrate_db():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
|
|||||||
@ -507,6 +507,9 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_doc_id_by_doc_name(cls, doc_name):
|
def get_doc_id_by_doc_name(cls, doc_name):
|
||||||
|
"""
|
||||||
|
highly rely on the strict deduplication guarantee from Document
|
||||||
|
"""
|
||||||
fields = [cls.model.id]
|
fields = [cls.model.id]
|
||||||
doc_id = cls.model.select(*fields) \
|
doc_id = cls.model.select(*fields) \
|
||||||
.where(cls.model.name == doc_name)
|
.where(cls.model.name == doc_name)
|
||||||
@ -656,6 +659,7 @@ class DocumentService(CommonService):
|
|||||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||||
else:
|
else:
|
||||||
|
prg = 1
|
||||||
status = TaskStatus.DONE.value
|
status = TaskStatus.DONE.value
|
||||||
|
|
||||||
msg = "\n".join(sorted(msg))
|
msg = "\n".join(sorted(msg))
|
||||||
@ -741,7 +745,11 @@ class DocumentService(CommonService):
|
|||||||
"cancelled": int(cancelled),
|
"cancelled": int(cancelled),
|
||||||
}
|
}
|
||||||
|
|
||||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||||
|
"""
|
||||||
|
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||||
|
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||||
|
"""
|
||||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||||
hasher = xxhash.xxh64()
|
hasher = xxhash.xxh64()
|
||||||
for field in sorted(chunking_config.keys()):
|
for field in sorted(chunking_config.keys()):
|
||||||
@ -751,7 +759,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
|||||||
nonlocal doc
|
nonlocal doc
|
||||||
return {
|
return {
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"doc_id": doc["id"],
|
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||||
"from_page": 100000000,
|
"from_page": 100000000,
|
||||||
"to_page": 100000000,
|
"to_page": 100000000,
|
||||||
"task_type": ty,
|
"task_type": ty,
|
||||||
@ -764,7 +772,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
|||||||
hasher.update(ty.encode("utf-8"))
|
hasher.update(ty.encode("utf-8"))
|
||||||
task["digest"] = hasher.hexdigest()
|
task["digest"] = hasher.hexdigest()
|
||||||
bulk_insert_into_db(Task, [task], True)
|
bulk_insert_into_db(Task, [task], True)
|
||||||
|
|
||||||
|
if ty == "graphrag":
|
||||||
|
task["doc_ids"] = doc_ids
|
||||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||||
|
return task["id"]
|
||||||
|
|
||||||
|
|
||||||
def get_queue_length(priority):
|
def get_queue_length(priority):
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class PipelineOperationLogService(CommonService):
|
|||||||
model = PipelineOperationLog
|
model = PipelineOperationLog
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cls_model_fields(cls):
|
def get_file_logs_fields(cls):
|
||||||
return [
|
return [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.document_id,
|
cls.model.document_id,
|
||||||
@ -59,9 +59,29 @@ class PipelineOperationLogService(CommonService):
|
|||||||
cls.model.update_date,
|
cls.model.update_date,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_dataset_logs_fields(cls):
|
||||||
|
return [
|
||||||
|
cls.model.id,
|
||||||
|
cls.model.tenant_id,
|
||||||
|
cls.model.kb_id,
|
||||||
|
cls.model.progress,
|
||||||
|
cls.model.progress_msg,
|
||||||
|
cls.model.process_begin_at,
|
||||||
|
cls.model.process_duration,
|
||||||
|
cls.model.task_type,
|
||||||
|
cls.model.operation_status,
|
||||||
|
cls.model.avatar,
|
||||||
|
cls.model.status,
|
||||||
|
cls.model.create_time,
|
||||||
|
cls.model.create_date,
|
||||||
|
cls.model.update_time,
|
||||||
|
cls.model.update_date,
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def create(cls, document_id, pipeline_id, task_type):
|
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||||
from rag.flow.pipeline import Pipeline
|
from rag.flow.pipeline import Pipeline
|
||||||
|
|
||||||
tenant_id = ""
|
tenant_id = ""
|
||||||
@ -69,14 +89,19 @@ class PipelineOperationLogService(CommonService):
|
|||||||
avatar = ""
|
avatar = ""
|
||||||
dsl = ""
|
dsl = ""
|
||||||
operation_status = ""
|
operation_status = ""
|
||||||
|
referred_document_id = document_id
|
||||||
|
|
||||||
ok, document = DocumentService.get_by_id(document_id)
|
if referred_document_id == "x" and fake_document_ids:
|
||||||
|
referred_document_id = fake_document_ids[0]
|
||||||
|
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise RuntimeError(f"Document {document_id} not found")
|
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||||
DocumentService.update_progress_immediately([document.to_dict()])
|
DocumentService.update_progress_immediately([document.to_dict()])
|
||||||
ok, document = DocumentService.get_by_id(document_id)
|
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise RuntimeError(f"Document {document_id} not found")
|
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||||
|
if document.progress not in [1, -1]:
|
||||||
|
return
|
||||||
operation_status = document.run
|
operation_status = document.run
|
||||||
|
|
||||||
if pipeline_id:
|
if pipeline_id:
|
||||||
@ -84,7 +109,7 @@ class PipelineOperationLogService(CommonService):
|
|||||||
if not ok:
|
if not ok:
|
||||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||||
|
|
||||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
|
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
|
||||||
|
|
||||||
tenant_id = user_pipeline.user_id
|
tenant_id = user_pipeline.user_id
|
||||||
title = user_pipeline.title
|
title = user_pipeline.title
|
||||||
@ -93,7 +118,7 @@ class PipelineOperationLogService(CommonService):
|
|||||||
else:
|
else:
|
||||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
|
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||||
|
|
||||||
tenant_id = kb_info.tenant_id
|
tenant_id = kb_info.tenant_id
|
||||||
title = document.name
|
title = document.name
|
||||||
@ -104,7 +129,7 @@ class PipelineOperationLogService(CommonService):
|
|||||||
|
|
||||||
log = dict(
|
log = dict(
|
||||||
id=get_uuid(),
|
id=get_uuid(),
|
||||||
document_id=document_id,
|
document_id=document_id, # "x" or real document_id
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
kb_id=document.kb_id,
|
kb_id=document.kb_id,
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
@ -132,18 +157,20 @@ class PipelineOperationLogService(CommonService):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
|
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
|
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||||
fields = cls.get_cls_model_fields()
|
fields = cls.get_file_logs_fields()
|
||||||
if keywords:
|
if keywords:
|
||||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||||
else:
|
else:
|
||||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||||
|
|
||||||
|
logs = logs.where(cls.model.document_id != "x")
|
||||||
|
|
||||||
if operation_status:
|
if operation_status:
|
||||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||||
if types:
|
if types:
|
||||||
@ -161,3 +188,23 @@ class PipelineOperationLogService(CommonService):
|
|||||||
logs = logs.paginate(page_number, items_per_page)
|
logs = logs.paginate(page_number, items_per_page)
|
||||||
|
|
||||||
return list(logs.dicts()), count
|
return list(logs.dicts()), count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
|
||||||
|
fields = cls.get_dataset_logs_fields()
|
||||||
|
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
|
||||||
|
|
||||||
|
if operation_status:
|
||||||
|
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||||
|
|
||||||
|
count = logs.count()
|
||||||
|
if desc:
|
||||||
|
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||||
|
else:
|
||||||
|
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||||
|
|
||||||
|
if page_number and items_per_page:
|
||||||
|
logs = logs.paginate(page_number, items_per_page)
|
||||||
|
|
||||||
|
return list(logs.dicts()), count
|
||||||
|
|||||||
@ -70,7 +70,7 @@ class TaskService(CommonService):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_task(cls, task_id):
|
def get_task(cls, task_id, doc_ids=[]):
|
||||||
"""Retrieve detailed task information by task ID.
|
"""Retrieve detailed task information by task ID.
|
||||||
|
|
||||||
This method fetches comprehensive task details including associated document,
|
This method fetches comprehensive task details including associated document,
|
||||||
@ -84,6 +84,10 @@ class TaskService(CommonService):
|
|||||||
dict: Task details dictionary containing all task information and related metadata.
|
dict: Task details dictionary containing all task information and related metadata.
|
||||||
Returns None if task is not found or has exceeded retry limit.
|
Returns None if task is not found or has exceeded retry limit.
|
||||||
"""
|
"""
|
||||||
|
doc_id = cls.model.doc_id
|
||||||
|
if doc_id == "x" and doc_ids:
|
||||||
|
doc_id = doc_ids[0]
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.doc_id,
|
cls.model.doc_id,
|
||||||
@ -109,7 +113,7 @@ class TaskService(CommonService):
|
|||||||
]
|
]
|
||||||
docs = (
|
docs = (
|
||||||
cls.model.select(*fields)
|
cls.model.select(*fields)
|
||||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
.join(Document, on=(doc_id == Document.id))
|
||||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||||
.where(cls.model.id == task_id)
|
.where(cls.model.id == task_id)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import networkx as nx
|
|||||||
import trio
|
import trio
|
||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
|
from api.db.services.document_service import DocumentService
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from graphrag.entity_resolution import EntityResolution
|
from graphrag.entity_resolution import EntityResolution
|
||||||
@ -54,7 +55,7 @@ async def run_graphrag(
|
|||||||
start = trio.current_time()
|
start = trio.current_time()
|
||||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]):
|
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||||
chunks.append(d["content_with_weight"])
|
chunks.append(d["content_with_weight"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||||
@ -125,6 +126,212 @@ async def run_graphrag(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def run_graphrag_for_kb(
|
||||||
|
row: dict,
|
||||||
|
doc_ids: list[str],
|
||||||
|
language: str,
|
||||||
|
kb_parser_config: dict,
|
||||||
|
chat_model,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
*,
|
||||||
|
with_resolution: bool = True,
|
||||||
|
with_community: bool = True,
|
||||||
|
max_parallel_docs: int = 4,
|
||||||
|
) -> dict:
|
||||||
|
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||||
|
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||||
|
start = trio.current_time()
|
||||||
|
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||||
|
|
||||||
|
if not doc_ids:
|
||||||
|
logging.info(f"Fetching all docs for {kb_id}")
|
||||||
|
docs, _ = DocumentService.get_by_kb_id(
|
||||||
|
kb_id=kb_id,
|
||||||
|
page_number=0,
|
||||||
|
items_per_page=0,
|
||||||
|
orderby="create_time",
|
||||||
|
desc=False,
|
||||||
|
keywords="",
|
||||||
|
run_status=[],
|
||||||
|
types=[],
|
||||||
|
suffix=[],
|
||||||
|
)
|
||||||
|
doc_ids = [doc["id"] for doc in docs]
|
||||||
|
|
||||||
|
doc_ids = list(dict.fromkeys(doc_ids))
|
||||||
|
if not doc_ids:
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
|
||||||
|
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
|
||||||
|
|
||||||
|
def load_doc_chunks(doc_id: str) -> list[str]:
|
||||||
|
from rag.utils import num_tokens_from_string
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
for d in settings.retrievaler.chunk_list(
|
||||||
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
[kb_id],
|
||||||
|
fields=fields_for_chunks,
|
||||||
|
sort_by_position=True,
|
||||||
|
):
|
||||||
|
content = d["content_with_weight"]
|
||||||
|
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||||
|
current_chunk += content
|
||||||
|
else:
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
current_chunk = content
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
all_doc_chunks: dict[str, list[str]] = {}
|
||||||
|
total_chunks = 0
|
||||||
|
for doc_id in doc_ids:
|
||||||
|
chunks = load_doc_chunks(doc_id)
|
||||||
|
all_doc_chunks[doc_id] = chunks
|
||||||
|
total_chunks += len(chunks)
|
||||||
|
|
||||||
|
if total_chunks == 0:
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||||
|
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||||
|
|
||||||
|
semaphore = trio.Semaphore(max_parallel_docs)
|
||||||
|
|
||||||
|
subgraphs: dict[str, object] = {}
|
||||||
|
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||||
|
|
||||||
|
async def build_one(doc_id: str):
|
||||||
|
chunks = all_doc_chunks.get(doc_id, [])
|
||||||
|
if not chunks:
|
||||||
|
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||||
|
return
|
||||||
|
|
||||||
|
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
|
||||||
|
|
||||||
|
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||||
|
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||||
|
with trio.fail_after(deadline):
|
||||||
|
sg = await generate_subgraph(
|
||||||
|
kg_extractor,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
doc_id,
|
||||||
|
chunks,
|
||||||
|
language,
|
||||||
|
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||||
|
chat_model,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
if sg:
|
||||||
|
subgraphs[doc_id] = sg
|
||||||
|
callback(msg=f"{msg} done")
|
||||||
|
else:
|
||||||
|
failed_docs.append((doc_id, "subgraph is empty"))
|
||||||
|
callback(msg=f"{msg} empty")
|
||||||
|
except Exception as e:
|
||||||
|
failed_docs.append((doc_id, repr(e)))
|
||||||
|
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||||
|
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for doc_id in doc_ids:
|
||||||
|
nursery.start_soon(build_one, doc_id)
|
||||||
|
|
||||||
|
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||||
|
if not ok_docs:
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||||
|
now = trio.current_time()
|
||||||
|
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||||
|
|
||||||
|
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||||
|
await kb_lock.spin_acquire()
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||||
|
|
||||||
|
try:
|
||||||
|
union_nodes: set = set()
|
||||||
|
final_graph = None
|
||||||
|
|
||||||
|
for doc_id in ok_docs:
|
||||||
|
sg = subgraphs[doc_id]
|
||||||
|
union_nodes.update(set(sg.nodes()))
|
||||||
|
|
||||||
|
new_graph = await merge_subgraph(
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
doc_id,
|
||||||
|
sg,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
if new_graph is not None:
|
||||||
|
final_graph = new_graph
|
||||||
|
|
||||||
|
if final_graph is None:
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
|
||||||
|
else:
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
|
||||||
|
finally:
|
||||||
|
kb_lock.release()
|
||||||
|
|
||||||
|
if not with_resolution and not with_community:
|
||||||
|
now = trio.current_time()
|
||||||
|
callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||||
|
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||||
|
|
||||||
|
await kb_lock.spin_acquire()
|
||||||
|
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||||
|
|
||||||
|
try:
|
||||||
|
subgraph_nodes = set()
|
||||||
|
for sg in subgraphs.values():
|
||||||
|
subgraph_nodes.update(set(sg.nodes()))
|
||||||
|
|
||||||
|
if with_resolution:
|
||||||
|
await resolve_entities(
|
||||||
|
final_graph,
|
||||||
|
subgraph_nodes,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
None,
|
||||||
|
chat_model,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if with_community:
|
||||||
|
await extract_community(
|
||||||
|
final_graph,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
None,
|
||||||
|
chat_model,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
kb_lock.release()
|
||||||
|
|
||||||
|
now = trio.current_time()
|
||||||
|
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||||
|
return {
|
||||||
|
"ok_docs": ok_docs,
|
||||||
|
"failed_docs": failed_docs, # [(doc_id, error), ...]
|
||||||
|
"total_docs": len(doc_ids),
|
||||||
|
"total_chunks": total_chunks,
|
||||||
|
"seconds": now - start,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def generate_subgraph(
|
async def generate_subgraph(
|
||||||
extractor: Extractor,
|
extractor: Extractor,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|||||||
@ -34,9 +34,9 @@ class Pipeline(Graph):
|
|||||||
if isinstance(dsl, dict):
|
if isinstance(dsl, dict):
|
||||||
dsl = json.dumps(dsl, ensure_ascii=False)
|
dsl = json.dumps(dsl, ensure_ascii=False)
|
||||||
super().__init__(dsl, tenant_id, task_id)
|
super().__init__(dsl, tenant_id, task_id)
|
||||||
|
self._doc_id = doc_id
|
||||||
if self._doc_id == "x":
|
if self._doc_id == "x":
|
||||||
self._doc_id = None
|
self._doc_id = None
|
||||||
self._doc_id = doc_id
|
|
||||||
self._flow_id = flow_id
|
self._flow_id = flow_id
|
||||||
self._kb_id = None
|
self._kb_id = None
|
||||||
if self._doc_id:
|
if self._doc_id:
|
||||||
|
|||||||
@ -383,7 +383,7 @@ class Dealer:
|
|||||||
vector_column = f"q_{dim}_vec"
|
vector_column = f"q_{dim}_vec"
|
||||||
zero_vector = [0.0] * dim
|
zero_vector = [0.0] * dim
|
||||||
sim_np = np.array(sim)
|
sim_np = np.array(sim)
|
||||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||||
for i in idx:
|
for i in idx:
|
||||||
if sim[i] < similarity_threshold:
|
if sim[i] < similarity_threshold:
|
||||||
@ -444,12 +444,27 @@ class Dealer:
|
|||||||
def chunk_list(self, doc_id: str, tenant_id: str,
|
def chunk_list(self, doc_id: str, tenant_id: str,
|
||||||
kb_ids: list[str], max_count=1024,
|
kb_ids: list[str], max_count=1024,
|
||||||
offset=0,
|
offset=0,
|
||||||
fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
fields=["docnm_kwd", "content_with_weight", "img_id"],
|
||||||
|
sort_by_position: bool = False):
|
||||||
condition = {"doc_id": doc_id}
|
condition = {"doc_id": doc_id}
|
||||||
|
|
||||||
|
fields_set = set(fields or [])
|
||||||
|
if sort_by_position:
|
||||||
|
for need in ("page_num_int", "position_int", "top_int"):
|
||||||
|
if need not in fields_set:
|
||||||
|
fields_set.add(need)
|
||||||
|
fields = list(fields_set)
|
||||||
|
|
||||||
|
orderBy = OrderByExpr()
|
||||||
|
if sort_by_position:
|
||||||
|
orderBy.asc("page_num_int")
|
||||||
|
orderBy.asc("position_int")
|
||||||
|
orderBy.asc("top_int")
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
bs = 128
|
bs = 128
|
||||||
for p in range(offset, max_count, bs):
|
for p in range(offset, max_count, bs):
|
||||||
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
|
||||||
kb_ids)
|
kb_ids)
|
||||||
dict_chunks = self.dataStore.getFields(es_res, fields)
|
dict_chunks = self.dataStore.getFields(es_res, fields)
|
||||||
for id, doc in dict_chunks.items():
|
for id, doc in dict_chunks.items():
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
|
|||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from api.utils.base64_image import image2id
|
from api.utils.base64_image import image2id
|
||||||
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
||||||
from graphrag.general.index import run_graphrag
|
from graphrag.general.index import run_graphrag_for_kb
|
||||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||||
from rag.flow.pipeline import Pipeline
|
from rag.flow.pipeline import Pipeline
|
||||||
from rag.prompts import keyword_extraction, question_proposal, content_tagging
|
from rag.prompts import keyword_extraction, question_proposal, content_tagging
|
||||||
@ -85,6 +85,12 @@ FACTORY = {
|
|||||||
ParserType.TAG.value: tag
|
ParserType.TAG.value: tag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||||
|
"dataflow" : PipelineTaskType.PARSE,
|
||||||
|
"raptor": PipelineTaskType.RAPTOR,
|
||||||
|
"graphrag": PipelineTaskType.GRAPH_RAG,
|
||||||
|
}
|
||||||
|
|
||||||
UNACKED_ITERATOR = None
|
UNACKED_ITERATOR = None
|
||||||
|
|
||||||
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||||
@ -215,6 +221,10 @@ async def collect():
|
|||||||
canceled = False
|
canceled = False
|
||||||
if msg.get("doc_id", "") == "x":
|
if msg.get("doc_id", "") == "x":
|
||||||
task = msg
|
task = msg
|
||||||
|
if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
|
||||||
|
print(f"hack {msg['doc_ids']=}=",flush=True)
|
||||||
|
task = TaskService.get_task(msg["id"], msg["doc_ids"])
|
||||||
|
task["doc_ids"] = msg["doc_ids"]
|
||||||
else:
|
else:
|
||||||
task = TaskService.get_task(msg["id"])
|
task = TaskService.get_task(msg["id"])
|
||||||
|
|
||||||
@ -580,7 +590,19 @@ async def do_handle_task(task):
|
|||||||
with_resolution = graphrag_conf.get("resolution", False)
|
with_resolution = graphrag_conf.get("resolution", False)
|
||||||
with_community = graphrag_conf.get("community", False)
|
with_community = graphrag_conf.get("community", False)
|
||||||
async with kg_limiter:
|
async with kg_limiter:
|
||||||
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
||||||
|
result = await run_graphrag_for_kb(
|
||||||
|
row=task,
|
||||||
|
doc_ids=task.get("doc_ids", []),
|
||||||
|
language=task_language,
|
||||||
|
kb_parser_config=task_parser_config,
|
||||||
|
chat_model=chat_model,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
callback=progress_callback,
|
||||||
|
with_resolution=with_resolution,
|
||||||
|
with_community=with_community,
|
||||||
|
)
|
||||||
|
logging.info(f"GraphRAG task result for task {task}:\n{result}")
|
||||||
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
|
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
@ -650,7 +672,6 @@ async def do_handle_task(task):
|
|||||||
timer() - start_ts))
|
timer() - start_ts))
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||||
PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
|
|
||||||
|
|
||||||
time_cost = timer() - start_ts
|
time_cost = timer() - start_ts
|
||||||
task_time_cost = timer() - task_start_ts
|
task_time_cost = timer() - task_start_ts
|
||||||
@ -667,6 +688,10 @@ async def handle_task():
|
|||||||
if not task:
|
if not task:
|
||||||
await trio.sleep(5)
|
await trio.sleep(5)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
task_type = task["task_type"]
|
||||||
|
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
||||||
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
||||||
@ -686,7 +711,12 @@ async def handle_task():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
||||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE)
|
finally:
|
||||||
|
task_document_ids = []
|
||||||
|
if task_type in ["graphrag"]:
|
||||||
|
task_document_ids = task["doc_ids"]
|
||||||
|
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||||
|
|
||||||
redis_msg.ack()
|
redis_msg.ack()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user