mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-26 08:56:47 +08:00
Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)
### What problem does this PR solve? Add foundational support for GraphRAG dataset pipeline logs ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -14,17 +14,19 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils import get_uuid
|
||||
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -435,18 +437,60 @@ def list_pipeline_logs():
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in docs:
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
docs = filtered_docs
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -478,3 +522,68 @@ def pipeline_log_detail():
|
||||
return get_data_error_result(message="Invalid pipeline log ID")
|
||||
|
||||
return get_json_result(data=log.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
if not doc_ids:
|
||||
return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
document_ids = set()
|
||||
sample_document = {}
|
||||
for doc_id in doc_ids:
|
||||
ok, document = DocumentService.get_by_id(doc_id)
|
||||
if ok:
|
||||
document_ids.add(document.id)
|
||||
if not sample_document:
|
||||
sample_document = document.to_dict()
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_graphrag():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="GraphRAG Task ID Not Found")
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "DOWNLOAD"
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -1065,11 +1068,15 @@ def migrate_db():
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -507,6 +507,9 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
"""
|
||||
highly rely on the strict deduplication guarantee from Document
|
||||
"""
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
@ -656,6 +659,7 @@ class DocumentService(CommonService):
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
@ -741,7 +745,11 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
@ -751,7 +759,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
@ -764,7 +772,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty == "graphrag":
|
||||
task["doc_ids"] = doc_ids
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
|
||||
@ -31,7 +31,7 @@ class PipelineOperationLogService(CommonService):
|
||||
model = PipelineOperationLog
|
||||
|
||||
@classmethod
|
||||
def get_cls_model_fields(cls):
|
||||
def get_file_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.document_id,
|
||||
@ -59,9 +59,29 @@ class PipelineOperationLogService(CommonService):
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_dataset_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type):
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
tenant_id = ""
|
||||
@ -69,14 +89,19 @@ class PipelineOperationLogService(CommonService):
|
||||
avatar = ""
|
||||
dsl = ""
|
||||
operation_status = ""
|
||||
referred_document_id = document_id
|
||||
|
||||
ok, document = DocumentService.get_by_id(document_id)
|
||||
if referred_document_id == "x" and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document {document_id} not found")
|
||||
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(document_id)
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Document {document_id} not found")
|
||||
raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
|
||||
if pipeline_id:
|
||||
@ -84,7 +109,7 @@ class PipelineOperationLogService(CommonService):
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
|
||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
|
||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
|
||||
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
@ -93,7 +118,7 @@ class PipelineOperationLogService(CommonService):
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||
|
||||
tenant_id = kb_info.tenant_id
|
||||
title = document.name
|
||||
@ -104,7 +129,7 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id,
|
||||
document_id=document_id, # "x" or real document_id
|
||||
tenant_id=tenant_id,
|
||||
kb_id=document.kb_id,
|
||||
pipeline_id=pipeline_id,
|
||||
@ -132,18 +157,20 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
else:
|
||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
|
||||
logs = logs.where(cls.model.document_id != "x")
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if types:
|
||||
@ -161,3 +188,23 @@ class PipelineOperationLogService(CommonService):
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@ -70,7 +70,7 @@ class TaskService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_task(cls, task_id):
|
||||
def get_task(cls, task_id, doc_ids=[]):
|
||||
"""Retrieve detailed task information by task ID.
|
||||
|
||||
This method fetches comprehensive task details including associated document,
|
||||
@ -84,6 +84,10 @@ class TaskService(CommonService):
|
||||
dict: Task details dictionary containing all task information and related metadata.
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == "x" and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
@ -109,7 +113,7 @@ class TaskService(CommonService):
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Document, on=(doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
|
||||
Reference in New Issue
Block a user