Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)

### What problem does this PR solve?

Add foundational support for GraphRAG dataset pipeline logs

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-09-25 09:35:50 +08:00
committed by GitHub
parent a6039cf563
commit 840b2b5809
10 changed files with 469 additions and 36 deletions

View File

@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "DOWNLOAD"
DOWNLOAD = "Download"
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
pagerank = IntegerField(default=0, index=False)
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
def __str__(self):
@ -1065,11 +1068,15 @@ def migrate_db():
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
except Exception:
pass
logging.disable(logging.NOTSET)

View File

@ -507,6 +507,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
"""
highly rely on the strict deduplication guarantee from Document
"""
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
@ -656,6 +659,7 @@ class DocumentService(CommonService):
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
prg = 0.98 * len(tsks) / (len(tsks) + 1)
else:
prg = 1
status = TaskStatus.DONE.value
msg = "\n".join(sorted(msg))
@ -741,7 +745,11 @@ class DocumentService(CommonService):
"cancelled": int(cancelled),
}
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
"""
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
Optionally, specify a list of doc_ids to determine which documents participate in the task.
"""
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
@ -751,7 +759,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
nonlocal doc
return {
"id": get_uuid(),
"doc_id": doc["id"],
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
@ -764,7 +772,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
if ty == "graphrag":
task["doc_ids"] = doc_ids
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"]
def get_queue_length(priority):

View File

@ -31,7 +31,7 @@ class PipelineOperationLogService(CommonService):
model = PipelineOperationLog
@classmethod
def get_cls_model_fields(cls):
def get_file_logs_fields(cls):
return [
cls.model.id,
cls.model.document_id,
@ -59,9 +59,29 @@ class PipelineOperationLogService(CommonService):
cls.model.update_date,
]
@classmethod
def get_dataset_logs_fields(cls):
return [
cls.model.id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
@DB.connection_context()
def create(cls, document_id, pipeline_id, task_type):
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
from rag.flow.pipeline import Pipeline
tenant_id = ""
@ -69,14 +89,19 @@ class PipelineOperationLogService(CommonService):
avatar = ""
dsl = ""
operation_status = ""
referred_document_id = document_id
ok, document = DocumentService.get_by_id(document_id)
if referred_document_id == "x" and fake_document_ids:
referred_document_id = fake_document_ids[0]
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
raise RuntimeError(f"Document {document_id} not found")
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
DocumentService.update_progress_immediately([document.to_dict()])
ok, document = DocumentService.get_by_id(document_id)
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
raise RuntimeError(f"Document {document_id} not found")
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
if document.progress not in [1, -1]:
return
operation_status = document.run
if pipeline_id:
@ -84,7 +109,7 @@ class PipelineOperationLogService(CommonService):
if not ok:
raise RuntimeError(f"Pipeline {pipeline_id} not found")
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
tenant_id = user_pipeline.user_id
title = user_pipeline.title
@ -93,7 +118,7 @@ class PipelineOperationLogService(CommonService):
else:
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
if not ok:
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
tenant_id = kb_info.tenant_id
title = document.name
@ -104,7 +129,7 @@ class PipelineOperationLogService(CommonService):
log = dict(
id=get_uuid(),
document_id=document_id,
document_id=document_id, # "x" or real document_id
tenant_id=tenant_id,
kb_id=document.kb_id,
pipeline_id=pipeline_id,
@ -132,18 +157,20 @@ class PipelineOperationLogService(CommonService):
@classmethod
@DB.connection_context()
def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
fields = cls.get_cls_model_fields()
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
fields = cls.get_file_logs_fields()
if keywords:
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
else:
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
logs = logs.where(cls.model.document_id != "x")
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if types:
@ -161,3 +188,23 @@ class PipelineOperationLogService(CommonService):
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count
@classmethod
@DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count

View File

@ -70,7 +70,7 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def get_task(cls, task_id):
def get_task(cls, task_id, doc_ids=[]):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
@ -84,6 +84,10 @@ class TaskService(CommonService):
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
"""
doc_id = cls.model.doc_id
if doc_id == "x" and doc_ids:
doc_id = doc_ids[0]
fields = [
cls.model.id,
cls.model.doc_id,
@ -109,7 +113,7 @@ class TaskService(CommonService):
]
docs = (
cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(Document, on=(doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)