Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)

### What problem does this PR solve?

Add foundational support for GraphRAG dataset pipeline logs: a `graphrag_task_id` column on `Knowledgebase`, new `/list_pipeline_dataset_logs`, `/run_graphrag`, and `/trace_graphrag` endpoints, a knowledge-base-level runner `run_graphrag_for_kb`, and dataset-level log records keyed by the sentinel document ID `"x"` so they can be listed separately from per-file logs.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-09-25 09:35:50 +08:00 (committed by GitHub)
Parent: a6039cf563
Commit: 840b2b5809
10 changed files with 469 additions and 36 deletions

View File

@@ -14,17 +14,19 @@
# limitations under the License.
#
import json
import logging
from flask import request
from flask_login import login_required, current_user
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils import get_uuid
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -435,18 +437,60 @@ def list_pipeline_logs():
    suffix = req.get("suffix", [])
    try:
        logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
        if create_time_from or create_time_to:
            filtered_docs = []
            for doc in logs:
                doc_create_time = doc.get("create_time", 0)
                if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
                    filtered_docs.append(doc)
            logs = filtered_docs
        return get_json_result(data={"total": tol, "logs": logs})
    except Exception as e:
        return server_error_response(e)

@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_dataset_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
req = request.get_json()
operation_status = req.get("operation_status", [])
if operation_status:
invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
try:
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
if create_time_from or create_time_to:
filtered_docs = []
for doc in logs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
logs = filtered_docs
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
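Note: a minimal client-side sketch of the new dataset-log listing. The /v1/kb mount point, host, and credentials below are assumptions, not part of this diff; only the route name, query parameters, and request/response shape come from the code above.

import requests  # any HTTP client works; requests is just for illustration

BASE = "http://localhost:9380/v1/kb"           # assumed blueprint prefix
headers = {"Authorization": "Bearer <token>"}  # placeholder credentials

resp = requests.post(
    f"{BASE}/list_pipeline_dataset_logs",
    params={"kb_id": "<kb_id>", "page": 1, "page_size": 20, "orderby": "create_time", "desc": "true"},
    json={"operation_status": ["success", "failed"]},  # optional filter; unknown values are rejected
    headers=headers,
)
print(resp.json())  # get_json_result wraps {"total": ..., "logs": [...]} under "data"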
@@ -478,3 +522,68 @@ def pipeline_log_detail():
        return get_data_error_result(message="Invalid pipeline log ID")
    return get_json_result(data=log.to_dict())


@manager.route("/run_graphrag", methods=["POST"])  # noqa: F821
@login_required
def run_graphrag():
    req = request.json
    kb_id = req.get("kb_id", "")
    if not kb_id:
        return get_error_data_result(message='Lack of "KB ID"')
    doc_ids = req.get("doc_ids", [])
    if not doc_ids:
        return get_error_data_result(message="Need to specify document IDs to run Graph RAG")

    ok, kb = KnowledgebaseService.get_by_id(kb_id)
    if not ok:
        return get_error_data_result(message="Invalid Knowledgebase ID")

    task_id = kb.graphrag_task_id
    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
    if task and task.progress not in [-1, 1]:
        return get_error_data_result(message=f"Task in progress with status {task.progress}. A GraphRAG task is already running.")

    document_ids = set()
    sample_document = {}
    for doc_id in doc_ids:
        ok, document = DocumentService.get_by_id(doc_id)
        if ok:
            document_ids.add(document.id)
            if not sample_document:
                sample_document = document.to_dict()

    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
        logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")

    return get_json_result(data={"graphrag_task_id": task_id})


@manager.route("/trace_graphrag", methods=["GET"])  # noqa: F821
@login_required
def trace_graphrag():
    kb_id = request.args.get("kb_id", "")
    if not kb_id:
        return get_error_data_result(message='Lack of "KB ID"')

    ok, kb = KnowledgebaseService.get_by_id(kb_id)
    if not ok:
        return get_error_data_result(message="Invalid Knowledgebase ID")

    task_id = kb.graphrag_task_id
    if not task_id:
        return get_error_data_result(message="GraphRAG Task ID Not Found")
    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
    return get_json_result(data=task.to_dict())
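Note: a sketch of triggering and polling a knowledge-base-wide GraphRAG run through the two new routes. The /v1/kb prefix, host, and credentials are assumptions; the progress values 1 (done) and -1 (failed) mirror the check in run_graphrag above.

import time
import requests

BASE = "http://localhost:9380/v1/kb"           # assumed blueprint prefix
headers = {"Authorization": "Bearer <token>"}  # placeholder credentials

r = requests.post(f"{BASE}/run_graphrag", json={"kb_id": "<kb_id>", "doc_ids": ["<doc_id_1>", "<doc_id_2>"]}, headers=headers)
task_id = r.json()["data"]["graphrag_task_id"]

# Poll the task stored on the knowledge base until it finishes or fails.
while True:
    t = requests.get(f"{BASE}/trace_graphrag", params={"kb_id": "<kb_id>"}, headers=headers).json()["data"]
    if t and t.get("progress") in (1, -1):
        break
    time.sleep(5)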

View File

@@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class PipelineTaskType(StrEnum):
    PARSE = "Parse"
    DOWNLOAD = "Download"
    RAPTOR = "RAPTOR"
    GRAPH_RAG = "GraphRAG"


VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}

KNOWLEDGEBASE_FOLDER_NAME = ".knowledgebase"

View File

@@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
    pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)
    graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
@@ -1065,11 +1068,15 @@ def migrate_db():
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)))
    except Exception:
        pass
    logging.disable(logging.NOTSET)

View File

@@ -507,6 +507,9 @@ class DocumentService(CommonService):
    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        """
        Relies on the strict name de-duplication guarantee of Document.
        """
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
@@ -656,6 +659,7 @@ class DocumentService(CommonService):
                queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                prg = 0.98 * len(tsks) / (len(tsks) + 1)
            else:
                prg = 1
                status = TaskStatus.DONE.value

            msg = "\n".join(sorted(msg))
@@ -741,7 +745,11 @@ class DocumentService(CommonService):
            "cancelled": int(cancelled),
        }


def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
    """
    A fake_doc_id can be provided to lift the single-document restriction and queue the task at the knowledgebase level.
    Optionally, doc_ids specifies which documents participate in the task.
    """
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
@@ -751,7 +759,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": fake_doc_id if fake_doc_id else doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "task_type": ty,
@@ -764,7 +772,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
    if ty == "graphrag":
        task["doc_ids"] = doc_ids
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
    return task["id"]


def get_queue_length(priority):
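Note: how the new parameters are exercised by the /run_graphrag endpoint above — a sketch only; sample_doc stands for any document dict already fetched from the target knowledge base.

from api.db.services.document_service import queue_raptor_o_graphrag_tasks

# One real document supplies the chunking config; the sentinel doc_id "x" detaches the task
# from any single document, and the participating documents travel in doc_ids.
task_id = queue_raptor_o_graphrag_tasks(
    doc=sample_doc,              # assumption: a document dict from the KB, e.g. document.to_dict()
    ty="graphrag",
    priority=0,
    fake_doc_id="x",             # knowledge-base-level sentinel
    doc_ids=["doc_1", "doc_2"],  # documents that participate in the KB-level graph build
)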

View File

@@ -31,7 +31,7 @@ class PipelineOperationLogService(CommonService):
    model = PipelineOperationLog

    @classmethod
    def get_file_logs_fields(cls):
        return [
            cls.model.id,
            cls.model.document_id,
@@ -59,9 +59,29 @@
            cls.model.update_date,
        ]

    @classmethod
    def get_dataset_logs_fields(cls):
        return [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.kb_id,
            cls.model.progress,
            cls.model.progress_msg,
            cls.model.process_begin_at,
            cls.model.process_duration,
            cls.model.task_type,
            cls.model.operation_status,
            cls.model.avatar,
            cls.model.status,
            cls.model.create_time,
            cls.model.create_date,
            cls.model.update_time,
            cls.model.update_date,
        ]
    @classmethod
    @DB.connection_context()
    def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
        from rag.flow.pipeline import Pipeline

        tenant_id = ""
@@ -69,14 +89,19 @@
        avatar = ""
        dsl = ""
        operation_status = ""
        referred_document_id = document_id
        if referred_document_id == "x" and fake_document_ids:
            referred_document_id = fake_document_ids[0]
        ok, document = DocumentService.get_by_id(referred_document_id)
        if not ok:
            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
        DocumentService.update_progress_immediately([document.to_dict()])
        ok, document = DocumentService.get_by_id(referred_document_id)
        if not ok:
            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
        if document.progress not in [1, -1]:
            return

        operation_status = document.run
        if pipeline_id:
@@ -84,7 +109,7 @@
            if not ok:
                raise RuntimeError(f"Pipeline {pipeline_id} not found")
            pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
            tenant_id = user_pipeline.user_id
            title = user_pipeline.title
@@ -93,7 +118,7 @@
        else:
            ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
            if not ok:
                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
            tenant_id = kb_info.tenant_id
            title = document.name
@@ -104,7 +129,7 @@
        log = dict(
            id=get_uuid(),
            document_id=document_id,  # "x" or a real document_id
            tenant_id=tenant_id,
            kb_id=document.kb_id,
            pipeline_id=pipeline_id,
@@ -132,18 +157,20 @@
    @classmethod
    @DB.connection_context()
    def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)

    @classmethod
    @DB.connection_context()
    def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
        fields = cls.get_file_logs_fields()
        if keywords:
            logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
        else:
            logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
        logs = logs.where(cls.model.document_id != "x")

        if operation_status:
            logs = logs.where(cls.model.operation_status.in_(operation_status))
        if types:
@@ -161,3 +188,23 @@
            logs = logs.paginate(page_number, items_per_page)

        return list(logs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
        fields = cls.get_dataset_logs_fields()
        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
        if operation_status:
            logs = logs.where(cls.model.operation_status.in_(operation_status))

        count = logs.count()
        if desc:
            logs = logs.order_by(cls.model.getter_by(orderby).desc())
        else:
            logs = logs.order_by(cls.model.getter_by(orderby).asc())

        if page_number and items_per_page:
            logs = logs.paginate(page_number, items_per_page)

        return list(logs.dicts()), count
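Note: the split between file-level and dataset-level logs hinges on the "x" sentinel — get_file_logs_by_kb_id now excludes it while get_dataset_logs_by_kb_id selects only it. A minimal sketch of consuming both from service code; the kb_id and paging values are placeholders.

from api.db.services.pipeline_operation_log_service import PipelineOperationLogService

file_logs, file_total = PipelineOperationLogService.get_file_logs_by_kb_id(
    "<kb_id>", 1, 20, "create_time", True, keywords="", operation_status=[], types=[], suffix=[]
)
dataset_logs, dataset_total = PipelineOperationLogService.get_dataset_logs_by_kb_id(
    "<kb_id>", 1, 20, "create_time", True, operation_status=["running"]
)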

View File

@@ -70,7 +70,7 @@ class TaskService(CommonService):
    @classmethod
    @DB.connection_context()
    def get_task(cls, task_id, doc_ids=[]):
        """Retrieve detailed task information by task ID.

        This method fetches comprehensive task details including associated document,

@@ -84,6 +84,10 @@
            dict: Task details dictionary containing all task information and related metadata.
                Returns None if task is not found or has exceeded retry limit.
        """
        doc_id = cls.model.doc_id
        if doc_id == "x" and doc_ids:
            doc_id = doc_ids[0]
        fields = [
            cls.model.id,
            cls.model.doc_id,
@@ -109,7 +113,7 @@
        ]
        docs = (
            cls.model.select(*fields)
            .join(Document, on=(doc_id == Document.id))
            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == task_id)

View File

@@ -21,6 +21,7 @@ import networkx as nx
import trio
from api import settings
from api.db.services.document_service import DocumentService
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.entity_resolution import EntityResolution
@@ -54,7 +55,7 @@ async def run_graphrag(
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
        chunks.append(d["content_with_weight"])

    with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -125,6 +126,212 @@ async def run_graphrag(
    return


async def run_graphrag_for_kb(
    row: dict,
    doc_ids: list[str],
    language: str,
    kb_parser_config: dict,
    chat_model,
    embedding_model,
    callback,
    *,
    with_resolution: bool = True,
    with_community: bool = True,
    max_parallel_docs: int = 4,
) -> dict:
    tenant_id, kb_id = row["tenant_id"], row["kb_id"]
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = trio.current_time()

    fields_for_chunks = ["content_with_weight", "doc_id"]

    if not doc_ids:
        logging.info(f"Fetching all docs for {kb_id}")
        docs, _ = DocumentService.get_by_kb_id(
            kb_id=kb_id,
            page_number=0,
            items_per_page=0,
            orderby="create_time",
            desc=False,
            keywords="",
            run_status=[],
            types=[],
            suffix=[],
        )
        doc_ids = [doc["id"] for doc in docs]

    doc_ids = list(dict.fromkeys(doc_ids))
    if not doc_ids:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
        return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}

    def load_doc_chunks(doc_id: str) -> list[str]:
        from rag.utils import num_tokens_from_string

        chunks = []
        current_chunk = ""
        for d in settings.retrievaler.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            fields=fields_for_chunks,
            sort_by_position=True,
        ):
            content = d["content_with_weight"]
            if num_tokens_from_string(current_chunk + content) < 1024:
                current_chunk += content
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = content
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    all_doc_chunks: dict[str, list[str]] = {}
    total_chunks = 0
    for doc_id in doc_ids:
        chunks = load_doc_chunks(doc_id)
        all_doc_chunks[doc_id] = chunks
        total_chunks += len(chunks)

    if total_chunks == 0:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
        return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}

    semaphore = trio.Semaphore(max_parallel_docs)
    subgraphs: dict[str, object] = {}
    failed_docs: list[tuple[str, str]] = []  # (doc_id, error)

    async def build_one(doc_id: str):
        chunks = all_doc_chunks.get(doc_id, [])
        if not chunks:
            callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
            return
        kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
        deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
        async with semaphore:
            try:
                msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
                callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
                with trio.fail_after(deadline):
                    sg = await generate_subgraph(
                        kg_extractor,
                        tenant_id,
                        kb_id,
                        doc_id,
                        chunks,
                        language,
                        kb_parser_config.get("graphrag", {}).get("entity_types", []),
                        chat_model,
                        embedding_model,
                        callback,
                    )
                if sg:
                    subgraphs[doc_id] = sg
                    callback(msg=f"{msg} done")
                else:
                    failed_docs.append((doc_id, "subgraph is empty"))
                    callback(msg=f"{msg} empty")
            except Exception as e:
                failed_docs.append((doc_id, repr(e)))
                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")

    async with trio.open_nursery() as nursery:
        for doc_id in doc_ids:
            nursery.start_soon(build_one, doc_id)

    ok_docs = [d for d in doc_ids if d in subgraphs]
    if not ok_docs:
        callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
        now = trio.current_time()
        return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")

    try:
        union_nodes: set = set()
        final_graph = None
        for doc_id in ok_docs:
            sg = subgraphs[doc_id]
            union_nodes.update(set(sg.nodes()))
            new_graph = await merge_subgraph(
                tenant_id,
                kb_id,
                doc_id,
                sg,
                embedding_model,
                callback,
            )
            if new_graph is not None:
                final_graph = new_graph
        if final_graph is None:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
        else:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
    finally:
        kb_lock.release()

    if not with_resolution and not with_community:
        now = trio.current_time()
        callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
    try:
        subgraph_nodes = set()
        for sg in subgraphs.values():
            subgraph_nodes.update(set(sg.nodes()))
        if with_resolution:
            await resolve_entities(
                final_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )
        if with_community:
            await extract_community(
                final_graph,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        kb_lock.release()

    now = trio.current_time()
    callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
    return {
        "ok_docs": ok_docs,
        "failed_docs": failed_docs,  # [(doc_id, error), ...]
        "total_docs": len(doc_ids),
        "total_chunks": total_chunks,
        "seconds": now - start,
    }


async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,

View File

@@ -34,9 +34,9 @@ class Pipeline(Graph):
        if isinstance(dsl, dict):
            dsl = json.dumps(dsl, ensure_ascii=False)
        super().__init__(dsl, tenant_id, task_id)
        self._doc_id = doc_id
        if self._doc_id == "x":
            self._doc_id = None
        self._flow_id = flow_id
        self._kb_id = None
        if self._doc_id:

View File

@@ -383,7 +383,7 @@ class Dealer:
        vector_column = f"q_{dim}_vec"
        zero_vector = [0.0] * dim
        sim_np = np.array(sim)
        filtered_count = (sim_np >= similarity_threshold).sum()
        ranks["total"] = int(filtered_count)  # convert np.int64 to Python int, otherwise JSON serialization fails
        for i in idx:
            if sim[i] < similarity_threshold:
@@ -444,12 +444,27 @@
    def chunk_list(self, doc_id: str, tenant_id: str,
                   kb_ids: list[str], max_count=1024,
                   offset=0,
                   fields=["docnm_kwd", "content_with_weight", "img_id"],
                   sort_by_position: bool = False):
        condition = {"doc_id": doc_id}
        fields_set = set(fields or [])
        if sort_by_position:
            for need in ("page_num_int", "position_int", "top_int"):
                if need not in fields_set:
                    fields_set.add(need)
            fields = list(fields_set)

        orderBy = OrderByExpr()
        if sort_by_position:
            orderBy.asc("page_num_int")
            orderBy.asc("position_int")
            orderBy.asc("top_int")

        res = []
        bs = 128
        for p in range(offset, max_count, bs):
            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
                                           kb_ids)
            dict_chunks = self.dataStore.getFields(es_res, fields)
            for id, doc in dict_chunks.items():
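Note: a sketch of the new sort_by_position flag as used by the GraphRAG code earlier in this diff; the doc_id/tenant_id/kb_id values are placeholders. The flag transparently adds page_num_int/position_int/top_int to the requested fields and returns chunks in document order instead of store order.

from api import settings

for d in settings.retrievaler.chunk_list(
    "<doc_id>",
    "<tenant_id>",
    ["<kb_id>"],
    fields=["content_with_weight", "doc_id"],
    sort_by_position=True,   # chunks come back in page/position order
):
    print(d["content_with_weight"][:80])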

View File

@@ -25,7 +25,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -85,6 +85,12 @@ FACTORY = {
    ParserType.TAG.value: tag
}

TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
    "dataflow": PipelineTaskType.PARSE,
    "raptor": PipelineTaskType.RAPTOR,
    "graphrag": PipelineTaskType.GRAPH_RAG,
}

UNACKED_ITERATOR = None

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@@ -215,6 +221,10 @@ async def collect():
    canceled = False
    if msg.get("doc_id", "") == "x":
        task = msg
        if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
            print(f"hack {msg['doc_ids']=}", flush=True)
            task = TaskService.get_task(msg["id"], msg["doc_ids"])
            task["doc_ids"] = msg["doc_ids"]
    else:
        task = TaskService.get_task(msg["id"])
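Note: the shape of the queued GraphRAG message that this branch rehydrates, limited to the fields visible in this diff (other fields produced by queue_raptor_o_graphrag_tasks, such as priority, are not shown here and are omitted).

msg = {
    "id": "<task uuid>",
    "doc_id": "x",                      # knowledge-base-level sentinel
    "from_page": 100000000,
    "to_page": 100000000,
    "task_type": "graphrag",
    "digest": "<xxhash of chunking config + task type>",
    "doc_ids": ["doc_1", "doc_2"],      # appended after bulk_insert_into_db, so only the queue copy carries it
}
# collect() then calls TaskService.get_task(msg["id"], msg["doc_ids"]) so the joined
# Document/Knowledgebase/Tenant metadata comes from the first participating document.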
@@ -580,7 +590,19 @@ async def do_handle_task(task):
        with_resolution = graphrag_conf.get("resolution", False)
        with_community = graphrag_conf.get("community", False)
        async with kg_limiter:
            # await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
            result = await run_graphrag_for_kb(
                row=task,
                doc_ids=task.get("doc_ids", []),
                language=task_language,
                kb_parser_config=task_parser_config,
                chat_model=chat_model,
                embedding_model=embedding_model,
                callback=progress_callback,
                with_resolution=with_resolution,
                with_community=with_community,
            )
            logging.info(f"GraphRAG task result for task {task}:\n{result}")
        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
        return
    else:
@@ -650,7 +672,6 @@ async def do_handle_task(task):
                                                     timer() - start_ts))
    DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
-   PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)

    time_cost = timer() - start_ts
    task_time_cost = timer() - task_start_ts
@@ -667,6 +688,10 @@ async def handle_task():
    if not task:
        await trio.sleep(5)
        return

    task_type = task["task_type"]
    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
    try:
        logging.info(f"handle_task begin for task {json.dumps(task)}")
        CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@@ -686,7 +711,12 @@ async def handle_task():
        except Exception:
            pass
        logging.exception(f"handle_task got exception for task {json.dumps(task)}")
    finally:
        task_document_ids = []
        if task_type in ["graphrag"]:
            task_document_ids = task["doc_ids"]
        PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
        redis_msg.ack()