#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import os

import networkx as nx
import trio

from api import settings
from api.db.services.document_service import DocumentService
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.entity_resolution import EntityResolution
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.utils import (
    GraphChange,
    chunk_id,
    does_graph_contains,
    get_graph,
    graph_merge,
    set_graph,
    tidy_graph,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock

# Deadline (in seconds) passed to ``trio.fail_after`` when the
# ENABLE_TIMEOUT_ASSERTION environment variable is not set.
_NO_TIMEOUT_SECONDS = 10000000000


def _graphrag_options(parser_config):
    """Return ``(extractor_class, entity_types)`` for a KB parser config.

    The general extractor is chosen only when the ``graphrag`` section sets
    ``method`` to ``"general"``; in every other case (including a missing
    section) the light extractor is used. ``entity_types`` defaults to ``[]``.
    """
    graphrag_cfg = (parser_config or {}).get("graphrag") or {}
    extractor_cls = GeneralKGExt if graphrag_cfg.get("method") == "general" else LightKGExt
    return extractor_cls, graphrag_cfg.get("entity_types", [])


def _extraction_deadline(chunk_count: int, enforce) -> int:
    """Return the subgraph-extraction deadline in seconds.

    When ``enforce`` is truthy the deadline is 600 seconds per chunk with a
    floor of 120 seconds; otherwise a very large stand-in value is returned.
    """
    return max(120, chunk_count * 60 * 10) if enforce else _NO_TIMEOUT_SECONDS


async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    """Build the knowledge graph for one document and merge it into its KB graph.

    Args:
        row: Task description; ``tenant_id``, ``kb_id``, ``doc_id`` and
            ``kb_parser_config`` are read from it.
        language: Passed through to the graph extractor.
        with_resolution: Run entity resolution after the merge.
        with_community: Extract and index community reports after the merge.
        chat_model: LLM bundle handed to the extractors.
        embedding_model: Embedding bundle handed to ``set_graph``.
        callback: Progress reporter, always called as ``callback(msg=...)``.

    Returns:
        None.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = [
        d["content_with_weight"]
        for d in settings.retrievaler.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            fields=["content_with_weight", "doc_id"],
            sort_by_position=True,
        )
    ]

    # A parser config without a "graphrag" section must not raise here: the
    # helper falls back to the light extractor and an empty entity-type list.
    kg_extractor, entity_types = _graphrag_options(row["kb_parser_config"])

    with trio.fail_after(_extraction_deadline(len(chunks), enable_timeout_assertion)):
        subgraph = await generate_subgraph(
            kg_extractor,
            tenant_id,
            kb_id,
            doc_id,
            chunks,
            language,
            entity_types,
            chat_model,
            embedding_model,
            callback,
        )

    if not subgraph:
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        # NOTE(review): the lock is acquired again before each long phase while
        # it is already held; presumably this refreshes the lock's timeout --
        # confirm against RedisDistributedLock.spin_acquire.
        if with_resolution:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        graphrag_task_lock.release()
    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


async def run_graphrag_for_kb(
    row: dict,
    doc_ids: list[str],
    language: str,
    kb_parser_config: dict,
    chat_model,
    embedding_model,
    callback,
    *,
    with_resolution: bool = True,
    with_community: bool = True,
    max_parallel_docs: int = 4,
) -> dict:
    """Build and merge knowledge graphs for several documents of one KB.

    Subgraphs are generated concurrently (bounded by ``max_parallel_docs``),
    merged one by one under a KB-wide lock, and then optionally post-processed
    with entity resolution and community extraction.

    Args:
        row: Task description; ``tenant_id`` and ``kb_id`` are read from it.
        doc_ids: Documents to process. When empty, all documents of the KB are
            fetched through ``DocumentService.get_by_kb_id``. Duplicates are
            dropped while keeping first-seen order.
        language: Passed through to the graph extractor.
        kb_parser_config: KB parser config; its ``graphrag`` section selects
            the extractor and the entity types.
        chat_model: LLM bundle handed to the extractors.
        embedding_model: Embedding bundle handed to ``set_graph``.
        callback: Progress reporter, always called as ``callback(msg=...)``.
        with_resolution: Run entity resolution after the merge.
        with_community: Extract and index community reports after the merge.
        max_parallel_docs: Upper bound on concurrently generated subgraphs;
            values below 1 are treated as 1.

    Returns:
        A dict with the keys ``ok_docs``, ``failed_docs``, ``total_docs``,
        ``total_chunks`` and ``seconds``.
    """
    tenant_id, kb_id = row["tenant_id"], row["kb_id"]
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = trio.current_time()
    fields_for_chunks = ["content_with_weight", "doc_id"]

    if not doc_ids:
        logging.info(f"Fetching all docs for {kb_id}")
        docs, _ = DocumentService.get_by_kb_id(
            kb_id=kb_id,
            page_number=0,
            items_per_page=0,
            orderby="create_time",
            desc=False,
            keywords="",
            run_status=[],
            types=[],
            suffix=[],
        )
        doc_ids = [doc["id"] for doc in docs]

    # De-duplicate while preserving order.
    doc_ids = list(dict.fromkeys(doc_ids))
    if not doc_ids:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
        return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}

    def load_doc_chunks(doc_id: str) -> list[str]:
        """Load a document's chunks, concatenating neighbours while the result stays under 1024 tokens."""
        from rag.utils import num_tokens_from_string

        chunks = []
        current_chunk = ""

        for d in settings.retrievaler.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            fields=fields_for_chunks,
            sort_by_position=True,
        ):
            content = d["content_with_weight"]
            if num_tokens_from_string(current_chunk + content) < 1024:
                current_chunk += content
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = content

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    all_doc_chunks: dict[str, list[str]] = {}
    total_chunks = 0
    for doc_id in doc_ids:
        chunks = load_doc_chunks(doc_id)
        all_doc_chunks[doc_id] = chunks
        total_chunks += len(chunks)

    if total_chunks == 0:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
        # NOTE(review): this branch reports failed_docs as bare doc ids, while
        # every other return uses (doc_id, error) tuples -- confirm with
        # callers before unifying the shape.
        return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}

    # Semaphore(0) would park every build_one task forever and a negative
    # value raises ValueError, so always allow at least one document at a time.
    semaphore = trio.Semaphore(max(1, max_parallel_docs))

    # The extractor class and entity types are the same for every document.
    kg_extractor, entity_types = _graphrag_options(kb_parser_config)

    subgraphs: dict[str, object] = {}
    failed_docs: list[tuple[str, str]] = []  # (doc_id, error)

    async def build_one(doc_id: str):
        """Generate the subgraph of one document, recording success or failure."""
        chunks = all_doc_chunks.get(doc_id, [])
        if not chunks:
            callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
            return

        deadline = _extraction_deadline(len(chunks), enable_timeout_assertion)

        async with semaphore:
            try:
                msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
                callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
                with trio.fail_after(deadline):
                    sg = await generate_subgraph(
                        kg_extractor,
                        tenant_id,
                        kb_id,
                        doc_id,
                        chunks,
                        language,
                        entity_types,
                        chat_model,
                        embedding_model,
                        callback,
                    )
                if sg:
                    subgraphs[doc_id] = sg
                    callback(msg=f"{msg} done")
                else:
                    failed_docs.append((doc_id, "subgraph is empty"))
                    callback(msg=f"{msg} empty")
            except Exception as e:
                # A failure of one document must not abort the whole batch.
                failed_docs.append((doc_id, repr(e)))
                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")

    async with trio.open_nursery() as nursery:
        for doc_id in doc_ids:
            nursery.start_soon(build_one, doc_id)

    ok_docs = [d for d in doc_ids if d in subgraphs]
    if not ok_docs:
        callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
        now = trio.current_time()
        return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")

    try:
        final_graph = None

        # Merge sequentially; the graph returned by the last merge is the one
        # handed to resolution and community extraction.
        for doc_id in ok_docs:
            new_graph = await merge_subgraph(
                tenant_id,
                kb_id,
                doc_id,
                subgraphs[doc_id],
                embedding_model,
                callback,
            )
            if new_graph is not None:
                final_graph = new_graph

        if final_graph is None:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
        else:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
    finally:
        kb_lock.release()

    if not with_resolution and not with_community:
        now = trio.current_time()
        callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")

    try:
        # Names of all nodes contributed by the freshly generated subgraphs.
        subgraph_nodes = set()
        for sg in subgraphs.values():
            subgraph_nodes.update(sg.nodes())

        if with_resolution:
            await resolve_entities(
                final_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )

        if with_community:
            await extract_community(
                final_graph,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        kb_lock.release()

    now = trio.current_time()
    callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
    return {
        "ok_docs": ok_docs,
        "failed_docs": failed_docs,  # [(doc_id, error), ...]
        "total_docs": len(doc_ids),
        "total_chunks": total_chunks,
        "seconds": now - start,
    }


async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Extract entities and relations from a document and store them as a subgraph chunk.

    Args:
        extractor: Extractor class; instantiated with ``llm_bdl``, ``language``
            and ``entity_types`` and then awaited with ``(doc_id, chunks, callback)``.
        tenant_id: Tenant owning the index.
        kb_id: Knowledge base the document belongs to.
        doc_id: Document being processed.
        chunks: Text chunks of the document.
        language: Passed through to the extractor.
        entity_types: Passed through to the extractor.
        llm_bdl: LLM bundle handed to the extractor.
        embed_bdl: Not used by this function.
        callback: Progress reporter, always called as ``callback(msg=...)``.

    Returns:
        The generated ``networkx.Graph``, or ``None`` when
        ``does_graph_contains`` reports that the graph already covers ``doc_id``.
    """
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}")
        return None

    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
    )
    ents, rels = await ext(doc_id, chunks, callback)

    subgraph = nx.Graph()
    for ent in ents:
        assert "description" in ent, f"entity {ent} does not have description"
        ent["source_id"] = [doc_id]
        subgraph.add_node(ent["entity_name"], **ent)

    ignored_rels = 0
    for rel in rels:
        assert "description" in rel, f"relation {rel} does not have description"
        # Relations whose endpoints were not extracted as entities are dropped.
        if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
            ignored_rels += 1
            continue
        rel["source_id"] = [doc_id]
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            **rel,
        )
    if ignored_rels:
        callback(msg=f"ignored {ignored_rels} relations due to missing entities.")

    tidy_graph(subgraph, callback, check_attribute=False)

    subgraph.graph["source_id"] = [doc_id]
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)

    # Replace any previously stored subgraph chunk of this document.
    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
    await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)

    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


@timeout(60 * 3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
    """Merge a document subgraph into the stored KB graph and persist the result.

    Args:
        tenant_id: Tenant owning the index.
        kb_id: Knowledge base whose graph is updated.
        doc_id: Document the subgraph came from (used in the progress message).
        subgraph: Graph to merge; ``subgraph.graph["source_id"]`` is read.
        embedding_model: Embedding bundle handed to ``set_graph``.
        callback: Progress reporter, always called as ``callback(msg=...)``.

    Returns:
        The merged graph, with a ``pagerank`` attribute set on every node.
    """
    start = trio.current_time()
    change = GraphChange()
    old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
    if old_graph is not None:
        logging.info("Merge with an exiting graph...................")
        tidy_graph(old_graph, callback)
        new_graph = graph_merge(old_graph, subgraph, change)
    else:
        # No stored graph yet: the subgraph becomes the graph and everything
        # in it counts as added.
        new_graph = subgraph
        change.added_updated_nodes = set(new_graph.nodes())
        change.added_updated_edges = set(new_graph.edges())

    pr = nx.pagerank(new_graph)
    for node_name, pagerank in pr.items():
        new_graph.nodes[node_name]["pagerank"] = pagerank

    await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
    now = trio.current_time()
    callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
    return new_graph


@timeout(60 * 30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Run entity resolution on ``graph`` and persist the resolved graph.

    Args:
        graph: Graph to resolve.
        subgraph_nodes: Node names handed to ``EntityResolution`` together
            with the graph.
        tenant_id: Tenant owning the index.
        kb_id: Knowledge base whose graph is updated.
        doc_id: Not used by this function.
        llm_bdl: LLM bundle handed to ``EntityResolution``.
        embed_bdl: Embedding bundle handed to ``set_graph``.
        callback: Progress reporter, always called as ``callback(msg=...)``.

    Returns:
        None.
    """
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
    )
    reso = await er(graph, subgraph_nodes, callback=callback)
    graph = reso.graph
    change = reso.change
    callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
    callback(msg="Graph resolution updated pagerank.")

    await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60 * 30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Extract community reports from ``graph`` and index them in the doc store.

    Existing ``community_report`` chunks of the KB are deleted before the new
    ones are inserted.

    Args:
        graph: Graph to analyse; ``graph.graph["source_id"]`` is read.
        tenant_id: Tenant owning the index.
        kb_id: Knowledge base the reports belong to.
        doc_id: Not used by this function.
        llm_bdl: LLM bundle handed to ``CommunityReportsExtractor``.
        embed_bdl: Not used by this function.
        callback: Progress reporter, always called as ``callback(msg=...)``.

    Returns:
        A ``(community_structure, community_reports)`` tuple taken from the
        extractor result.

    Raises:
        Exception: When the doc store returns a truthy result for an insert.
    """
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    doc_ids = graph.graph["source_id"]

    now = trio.current_time()
    callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
    start = now
    chunks = []
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]),
        }
        chunk = {
            "id": get_uuid(),
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": list(doc_ids),
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        # The lambda is awaited within the same iteration, so it always sees
        # the current value of ``b``.
        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)

    now = trio.current_time()
    callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
    return community_structure, community_reports