diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 9ee62d1e0..f644c9001 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -14,17 +14,19 @@
 # limitations under the License.
 #
 import json
+import logging
 
 from flask import request
 from flask_login import login_required, current_user
 
 from api.db.services import duplicate_name
-from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import TaskService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
+from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
 from api.utils import get_uuid
 from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -435,18 +437,60 @@ def list_pipeline_logs():
     suffix = req.get("suffix", [])
 
     try:
-        docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
+        logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
 
         if create_time_from or create_time_to:
             filtered_docs = []
-            for doc in docs:
+            for doc in logs:
                 doc_create_time = doc.get("create_time", 0)
                 if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
                     filtered_docs.append(doc)
-            docs = filtered_docs
+            logs = filtered_docs
 
-        return get_json_result(data={"total": tol, "docs": docs})
+        return get_json_result(data={"total": tol, "logs": logs})
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route("/list_pipeline_dataset_logs", methods=["POST"])  # noqa: F821
+@login_required
+def list_pipeline_dataset_logs():
+    kb_id = request.args.get("kb_id")
+    if not kb_id:
+        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
+
+    page_number = int(request.args.get("page", 0))
+    items_per_page = int(request.args.get("page_size", 0))
+    orderby = request.args.get("orderby", "create_time")
+    if request.args.get("desc", "true").lower() == "false":
+        desc = False
+    else:
+        desc = True
+    create_time_from = int(request.args.get("create_time_from", 0))
+    create_time_to = int(request.args.get("create_time_to", 0))
+
+    req = request.get_json()
+
+    operation_status = req.get("operation_status", [])
+    if operation_status:
+        invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
+        if invalid_status:
+            return get_data_error_result(message=f"Invalid operation_status filter conditions: {', '.join(invalid_status)}")
+
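+    # Dataset-level pipeline logs are stored with the placeholder document_id "x";
+    # get_dataset_logs_by_kb_id() selects only those rows.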
+    try:
+        logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
+
+        if create_time_from or create_time_to:
+            filtered_logs = []
+            for log in logs:
+                log_create_time = log.get("create_time", 0)
+                if (create_time_from == 0 or log_create_time >= create_time_from) and (create_time_to == 0 or log_create_time <= create_time_to):
+                    filtered_logs.append(log)
+            logs = filtered_logs
+
+        return get_json_result(data={"total": tol, "logs": logs})
     except Exception as e:
         return server_error_response(e)
 
@@ -478,3 +522,68 @@ def pipeline_log_detail():
         return get_data_error_result(message="Invalid pipeline log ID")
 
     return get_json_result(data=log.to_dict())
+
+
+@manager.route("/run_graphrag", methods=["POST"])  # noqa: F821
+@login_required
+def run_graphrag():
+    req = request.json
+
+    kb_id = req.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+
+    doc_ids = req.get("doc_ids", [])
+    if not doc_ids:
+        return get_error_data_result(message="Need to specify document IDs to run GraphRAG")
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    task_id = kb.graphrag_task_id
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
+
+    if task and task.progress not in [-1, 1]:
+        return get_error_data_result(message=f"A GraphRAG task is already running for this knowledgebase (progress {task.progress}).")
+
+    document_ids = set()
+    sample_document = {}
+    for doc_id in doc_ids:
+        ok, document = DocumentService.get_by_id(doc_id)
+        if ok:
+            document_ids.add(document.id)
+            if not sample_document:
+                sample_document = document.to_dict()
+
+    if not document_ids:
+        return get_error_data_result(message="No valid documents found for the given document IDs")
+
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
+
+    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
+        logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
+
+    return get_json_result(data={"graphrag_task_id": task_id})
+
+
+@manager.route("/trace_graphrag", methods=["GET"])  # noqa: F821
+@login_required
+def trace_graphrag():
+    kb_id = request.args.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    task_id = kb.graphrag_task_id
+    if not task_id:
+        return get_error_data_result(message="GraphRAG task ID not found")
+
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        return get_json_result(data=False, message="GraphRAG task not found or an error occurred", code=settings.RetCode.ARGUMENT_ERROR)
+
+    return get_json_result(data=task.to_dict())
diff --git a/api/db/__init__.py b/api/db/__init__.py
index 9bdfc726d..8f2806419 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
 
 class PipelineTaskType(StrEnum):
     PARSE = "Parse"
-    DOWNLOAD = "DOWNLOAD"
+    DOWNLOAD = "Download"
+    RAPTOR = "RAPTOR"
+    GRAPH_RAG = "GraphRAG"
 
 
-VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
+VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
 
 KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 07702c2e8..269fffdb5 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
     pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
     parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
     pagerank = IntegerField(default=0, index=False)
+
+    graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
+
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
 
     def __str__(self):
@@ -1065,11 +1068,15 @@ def migrate_db():
     except Exception:
         pass
     try:
-        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
+        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
     except Exception:
         pass
     try:
-        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
+        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)))
     except Exception:
         pass
     logging.disable(logging.NOTSET)
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 605f14049..56d08b230 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -507,6 +507,9 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
+        """
+        Relies on the strict deduplication guarantee of Document names.
+        """
         fields = [cls.model.id]
         doc_id = cls.model.select(*fields) \
             .where(cls.model.name == doc_name)
@@ -656,6 +659,7 @@ class DocumentService(CommonService):
                     queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 else:
+                    prg = 1
                     status = TaskStatus.DONE.value
 
             msg = "\n".join(sorted(msg))
@@ -741,7 +745,11 @@ class DocumentService(CommonService):
         "cancelled": int(cancelled),
     }
 
-def queue_raptor_o_graphrag_tasks(doc, ty, priority):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
+    """
+    Pass fake_doc_id to queue a knowledgebase-level task that is not bound to a single document.
+    Optionally, pass doc_ids to specify which documents participate in the task.
+    """
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -751,7 +759,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
         nonlocal doc
         return {
             "id": get_uuid(),
-            "doc_id": doc["id"],
+            "doc_id": fake_doc_id if fake_doc_id else doc["id"],
             "from_page": 100000000,
             "to_page": 100000000,
             "task_type": ty,
@@ -764,7 +772,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
+
+    if ty == "graphrag":
+        task["doc_ids"] = doc_ids
     assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
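+    # Return the task id so callers (e.g. the run_graphrag endpoint) can persist it
+    # on the knowledgebase as graphrag_task_id.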
+    return task["id"]
 
 
 def get_queue_length(priority):
diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py
index c547a6a06..d316cb46b 100644
--- a/api/db/services/pipeline_operation_log_service.py
+++ b/api/db/services/pipeline_operation_log_service.py
@@ -31,7 +31,7 @@ class PipelineOperationLogService(CommonService):
     model = PipelineOperationLog
 
     @classmethod
-    def get_cls_model_fields(cls):
+    def get_file_logs_fields(cls):
         return [
             cls.model.id,
             cls.model.document_id,
@@ -59,9 +59,29 @@ class PipelineOperationLogService(CommonService):
             cls.model.update_date,
         ]
 
+    @classmethod
+    def get_dataset_logs_fields(cls):
+        return [
+            cls.model.id,
+            cls.model.tenant_id,
+            cls.model.kb_id,
+            cls.model.progress,
+            cls.model.progress_msg,
+            cls.model.process_begin_at,
+            cls.model.process_duration,
+            cls.model.task_type,
+            cls.model.operation_status,
+            cls.model.avatar,
+            cls.model.status,
+            cls.model.create_time,
+            cls.model.create_date,
+            cls.model.update_time,
+            cls.model.update_date,
+        ]
+
     @classmethod
     @DB.connection_context()
-    def create(cls, document_id, pipeline_id, task_type):
+    def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
         from rag.flow.pipeline import Pipeline
 
         tenant_id = ""
@@ -69,14 +89,19 @@ class PipelineOperationLogService(CommonService):
         avatar = ""
         dsl = ""
         operation_status = ""
+        referred_document_id = document_id
 
-        ok, document = DocumentService.get_by_id(document_id)
+        if referred_document_id == "x" and fake_document_ids:
+            referred_document_id = fake_document_ids[0]
+        ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document {document_id} not found")
+            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
         DocumentService.update_progress_immediately([document.to_dict()])
-        ok, document = DocumentService.get_by_id(document_id)
+        ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
-            raise RuntimeError(f"Document {document_id} not found")
+            raise RuntimeError(f"Document for referred_document_id {referred_document_id} not found")
+        if document.progress not in [1, -1]:
+            return
 
         operation_status = document.run
         if pipeline_id:
@@ -84,7 +109,7 @@ class PipelineOperationLogService(CommonService):
             if not ok:
                 raise RuntimeError(f"Pipeline {pipeline_id} not found")
 
-            pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
+            pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
 
             tenant_id = user_pipeline.user_id
             title = user_pipeline.title
@@ -93,7 +118,7 @@ class PipelineOperationLogService(CommonService):
         else:
             ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
             if not ok:
-                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
+                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred document {referred_document_id}")
             tenant_id = kb_info.tenant_id
             title = document.name
@@ -104,7 +129,7 @@ class PipelineOperationLogService(CommonService):
 
         log = dict(
             id=get_uuid(),
-            document_id=document_id,
+            document_id=document_id,  # "x" or a real document ID
             tenant_id=tenant_id,
             kb_id=document.kb_id,
             pipeline_id=pipeline_id,
@@ -132,18 +157,20 @@ class PipelineOperationLogService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
-        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
+    def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
+        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
 
     @classmethod
     @DB.connection_context()
-    def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
-        fields = cls.get_cls_model_fields()
+    def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
+        fields = cls.get_file_logs_fields()
         if keywords:
             logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
         else:
             logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
 
+        logs = logs.where(cls.model.document_id != "x")
+
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
         if types:
@@ -161,3 +188,23 @@ class PipelineOperationLogService(CommonService):
             logs = logs.paginate(page_number, items_per_page)
 
         return list(logs.dicts()), count
+
+    @classmethod
+    @DB.connection_context()
+    def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
+        fields = cls.get_dataset_logs_fields()
+        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
+
+        if operation_status:
+            logs = logs.where(cls.model.operation_status.in_(operation_status))
+
+        count = logs.count()
+        if desc:
+            logs = logs.order_by(cls.model.getter_by(orderby).desc())
+        else:
+            logs = logs.order_by(cls.model.getter_by(orderby).asc())
+
+        if page_number and items_per_page:
+            logs = logs.paginate(page_number, items_per_page)
+
+        return list(logs.dicts()), count
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 9d29cbff9..215e5c724 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -70,7 +70,7 @@ class TaskService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_task(cls, task_id):
+    def get_task(cls, task_id, doc_ids=[]):
         """Retrieve detailed task information by task ID.
 
         This method fetches comprehensive task details including associated document,
@@ -84,6 +84,10 @@ class TaskService(CommonService):
             dict: Task details dictionary containing all task information and related
                 metadata. Returns None if task is not found or has exceeded retry limit.
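+
+        Note:
+            When doc_ids is provided (knowledgebase-level tasks are queued with the
+            placeholder doc_id "x"), the document join uses the first real document id.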
""" + doc_id = cls.model.doc_id + if doc_id == "x" and doc_ids: + doc_id = doc_ids[0] + fields = [ cls.model.id, cls.model.doc_id, @@ -109,7 +113,7 @@ class TaskService(CommonService): ] docs = ( cls.model.select(*fields) - .join(Document, on=(cls.model.doc_id == Document.id)) + .join(Document, on=(doc_id == Document.id)) .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) .where(cls.model.id == task_id) diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 9e80309f2..edb25c9ae 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -21,6 +21,7 @@ import networkx as nx import trio from api import settings +from api.db.services.document_service import DocumentService from api.utils import get_uuid from api.utils.api_utils import timeout from graphrag.entity_resolution import EntityResolution @@ -54,7 +55,7 @@ async def run_graphrag( start = trio.current_time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] - for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]): + for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): chunks.append(d["content_with_weight"]) with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): @@ -125,6 +126,212 @@ async def run_graphrag( return +async def run_graphrag_for_kb( + row: dict, + doc_ids: list[str], + language: str, + kb_parser_config: dict, + chat_model, + embedding_model, + callback, + *, + with_resolution: bool = True, + with_community: bool = True, + max_parallel_docs: int = 4, +) -> dict: + tenant_id, kb_id = row["tenant_id"], row["kb_id"] + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") + start = trio.current_time() + fields_for_chunks = ["content_with_weight", "doc_id"] + + if not doc_ids: + logging.info(f"Fetching all docs for {kb_id}") + docs, _ = DocumentService.get_by_kb_id( + kb_id=kb_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + doc_ids = [doc["id"] for doc in docs] + + doc_ids = list(dict.fromkeys(doc_ids)) + if not doc_ids: + callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.") + return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0} + + def load_doc_chunks(doc_id: str) -> list[str]: + from rag.utils import num_tokens_from_string + + chunks = [] + current_chunk = "" + + for d in settings.retrievaler.chunk_list( + doc_id, + tenant_id, + [kb_id], + fields=fields_for_chunks, + sort_by_position=True, + ): + content = d["content_with_weight"] + if num_tokens_from_string(current_chunk + content) < 1024: + current_chunk += content + else: + if current_chunk: + chunks.append(current_chunk) + current_chunk = content + + if current_chunk: + chunks.append(current_chunk) + + return chunks + + all_doc_chunks: dict[str, list[str]] = {} + total_chunks = 0 + for doc_id in doc_ids: + chunks = load_doc_chunks(doc_id) + all_doc_chunks[doc_id] = chunks + total_chunks += len(chunks) + + if total_chunks == 0: + callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") + return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} + + semaphore = trio.Semaphore(max_parallel_docs) + + subgraphs: 
+    failed_docs: list[tuple[str, str]] = []  # (doc_id, error)
+
+    async def build_one(doc_id: str):
+        chunks = all_doc_chunks.get(doc_id, [])
+        if not chunks:
+            callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
+            return
+
+        kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
+
+        deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
+
+        async with semaphore:
+            try:
+                msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
+                callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
+                with trio.fail_after(deadline):
+                    sg = await generate_subgraph(
+                        kg_extractor,
+                        tenant_id,
+                        kb_id,
+                        doc_id,
+                        chunks,
+                        language,
+                        kb_parser_config.get("graphrag", {}).get("entity_types", []),
+                        chat_model,
+                        embedding_model,
+                        callback,
+                    )
+                if sg:
+                    subgraphs[doc_id] = sg
+                    callback(msg=f"{msg} done")
+                else:
+                    failed_docs.append((doc_id, "subgraph is empty"))
+                    callback(msg=f"{msg} empty")
+            except Exception as e:
+                failed_docs.append((doc_id, repr(e)))
+                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
+
+    async with trio.open_nursery() as nursery:
+        for doc_id in doc_ids:
+            nursery.start_soon(build_one, doc_id)
+
+    ok_docs = [d for d in doc_ids if d in subgraphs]
+    if not ok_docs:
+        callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs were generated successfully, stopping.")
+        now = trio.current_time()
+        return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
+
+    kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
+    await kb_lock.spin_acquire()
+    callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
+
+    try:
+        union_nodes: set = set()
+        final_graph = None
+
+        for doc_id in ok_docs:
+            sg = subgraphs[doc_id]
+            union_nodes.update(set(sg.nodes()))
+
+            new_graph = await merge_subgraph(
+                tenant_id,
+                kb_id,
+                doc_id,
+                sg,
+                embedding_model,
+                callback,
+            )
+            if new_graph is not None:
+                final_graph = new_graph
+
+        if final_graph is None:
+            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
+        else:
+            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
+    finally:
+        kb_lock.release()
+
+    if not with_resolution and not with_community:
+        now = trio.current_time()
+        callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
+        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
+
+    await kb_lock.spin_acquire()
+    callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
+
+    try:
+        subgraph_nodes = set()
+        for sg in subgraphs.values():
+            subgraph_nodes.update(set(sg.nodes()))
+
+        if with_resolution:
+            await resolve_entities(
+                final_graph,
+                subgraph_nodes,
+                tenant_id,
+                kb_id,
+                None,
+                chat_model,
+                embedding_model,
+                callback,
+            )
+
+        if with_community:
+            await extract_community(
+                final_graph,
+                tenant_id,
+                kb_id,
+                None,
+                chat_model,
+                embedding_model,
+                callback,
+            )
+    finally:
+        kb_lock.release()
+
+    now = trio.current_time()
+    callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
+    return {
+        "ok_docs": ok_docs,
+        "failed_docs": failed_docs,  # [(doc_id, error), ...]
+        "total_docs": len(doc_ids),
+        "total_chunks": total_chunks,
+        "seconds": now - start,
+    }
+
+
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py
index 36ac531dc..4f2211df0 100644
--- a/rag/flow/pipeline.py
+++ b/rag/flow/pipeline.py
@@ -34,9 +34,9 @@ class Pipeline(Graph):
         if isinstance(dsl, dict):
             dsl = json.dumps(dsl, ensure_ascii=False)
         super().__init__(dsl, tenant_id, task_id)
+        self._doc_id = doc_id
         if self._doc_id == "x":
             self._doc_id = None
-        self._doc_id = doc_id
         self._flow_id = flow_id
         self._kb_id = None
         if self._doc_id:
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index b1617b9a7..db1423095 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -383,7 +383,7 @@ class Dealer:
         vector_column = f"q_{dim}_vec"
         zero_vector = [0.0] * dim
         sim_np = np.array(sim)
-        filtered_count = (sim_np >= similarity_threshold).sum()
+        filtered_count = (sim_np >= similarity_threshold).sum()
         ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error
         for i in idx:
             if sim[i] < similarity_threshold:
@@ -444,12 +444,27 @@ class Dealer:
     def chunk_list(self, doc_id: str, tenant_id: str,
                    kb_ids: list[str], max_count=1024,
                    offset=0,
-                   fields=["docnm_kwd", "content_with_weight", "img_id"]):
+                   fields=["docnm_kwd", "content_with_weight", "img_id"],
+                   sort_by_position: bool = False):
         condition = {"doc_id": doc_id}
+
+        fields_set = set(fields or [])
+        if sort_by_position:
+            for need in ("page_num_int", "position_int", "top_int"):
+                if need not in fields_set:
+                    fields_set.add(need)
+            fields = list(fields_set)
+
+        orderBy = OrderByExpr()
+        if sort_by_position:
+            orderBy.asc("page_num_int")
+            orderBy.asc("position_int")
+            orderBy.asc("top_int")
+
         res = []
         bs = 128
         for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
+            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
                                            kb_ids)
             dict_chunks = self.dataStore.getFields(es_res, fields)
             for id, doc in dict_chunks.items():
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index b6ccce655..d60c7e155 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -25,7 +25,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
 from api.utils.log_utils import init_root_logger, get_project_base_directory
-from graphrag.general.index import run_graphrag
+from graphrag.general.index import run_graphrag_for_kb
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -85,6 +85,12 @@ FACTORY = {
     ParserType.TAG.value: tag
 }
 
+TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
+    "dataflow": PipelineTaskType.PARSE,
+    "raptor": PipelineTaskType.RAPTOR,
+    "graphrag": PipelineTaskType.GRAPH_RAG,
+}
+
 UNACKED_ITERATOR = None
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@@ -215,6 +221,10 @@ async def collect():
     canceled = False
     if msg.get("doc_id", "") == "x":
         task = msg
+        if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
+            logging.debug(f"Knowledgebase-level GraphRAG task, doc_ids={msg['doc_ids']}")
+            task = TaskService.get_task(msg["id"], msg["doc_ids"])
+            task["doc_ids"] = msg["doc_ids"]
     else:
         task = TaskService.get_task(msg["id"])
@@ -580,7 +590,19 @@ async def do_handle_task(task):
             with_resolution = graphrag_conf.get("resolution", False)
             with_community = graphrag_conf.get("community", False)
             async with kg_limiter:
-                await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
+                result = await run_graphrag_for_kb(
+                    row=task,
+                    doc_ids=task.get("doc_ids", []),
+                    language=task_language,
+                    kb_parser_config=task_parser_config,
+                    chat_model=chat_model,
+                    embedding_model=embedding_model,
+                    callback=progress_callback,
+                    with_resolution=with_resolution,
+                    with_community=with_community,
+                )
+                logging.info(f"GraphRAG result for task {task['id']}:\n{result}")
             progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
             return
         else:
@@ -650,7 +672,6 @@ async def do_handle_task(task):
                                                                  timer() - start_ts))
 
     DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
-    PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
 
     time_cost = timer() - start_ts
     task_time_cost = timer() - task_start_ts
@@ -667,6 +688,10 @@ async def handle_task():
     if not task:
         await trio.sleep(5)
         return
+
+    task_type = task["task_type"]
+    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE)
+
     try:
         logging.info(f"handle_task begin for task {json.dumps(task)}")
         CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@@ -686,7 +711,12 @@ async def handle_task():
         except Exception:
             pass
         logging.exception(f"handle_task got exception for task {json.dumps(task)}")
-        PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE)
+    finally:
+        task_document_ids = []
+        if task_type in ["graphrag"]:
+            task_document_ids = task.get("doc_ids", [])
+        PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+
     redis_msg.ack()
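
For reviewers who want to exercise the change by hand, a minimal usage sketch of the two new endpoints follows. It assumes the kb blueprint is mounted under /v1/kb like the existing kb routes, a server on localhost:9380, and whatever credentials login_required accepts in your deployment; none of those values come from this diff.

    import requests

    BASE = "http://localhost:9380/v1/kb"           # assumed host and route prefix
    HEADERS = {"Authorization": "Bearer <token>"}  # placeholder; use your session/auth

    # Queue a knowledgebase-level GraphRAG task over the selected documents.
    resp = requests.post(
        f"{BASE}/run_graphrag",
        json={"kb_id": "<kb_id>", "doc_ids": ["<doc_id_1>", "<doc_id_2>"]},
        headers=HEADERS,
    )
    graphrag_task_id = resp.json()["data"]["graphrag_task_id"]

    # Poll the task recorded on the knowledgebase.
    trace = requests.get(f"{BASE}/trace_graphrag", params={"kb_id": "<kb_id>"}, headers=HEADERS)
    print(graphrag_task_id, trace.json()["data"]["progress"])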