Files
ragflow/graphrag/general/index.py
Jin Hai 1a9215bc6f Move some vars to globals (#11017)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 14:14:38 +08:00

515 lines
17 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import networkx as nx
import trio
from api.db.services.document_service import DocumentService
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from graphrag.entity_resolution import EntityResolution
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.utils import (
GraphChange,
chunk_id,
does_graph_contains,
get_graph,
graph_merge,
set_graph,
tidy_graph,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import globals
async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    """Build the knowledge subgraph of one document and merge it into its KB graph.

    The document's chunks are read through ``globals.retriever``, handed to the
    configured extractor (``GeneralKGExt`` when the KB's graphrag ``method`` is
    ``"general"``, ``LightKGExt`` otherwise) and the resulting subgraph is
    merged under a per-KB distributed lock. Entity resolution and community
    extraction are run afterwards when requested.

    Args:
        row: Task description; the keys ``tenant_id``, ``kb_id``, ``doc_id``
            and ``kb_parser_config`` are read.
        language: Forwarded to the extractor.
        with_resolution: Run ``resolve_entities`` after the merge.
        with_community: Run ``extract_community`` after the merge.
        chat_model: Forwarded to extraction, resolution and community steps.
        embedding_model: Forwarded to extraction, merge and post-merge steps.
        callback: Progress sink, invoked as ``callback(msg=...)``.

    Returns:
        None. The function returns early when no subgraph was produced, and
        right after the merge when neither resolution nor community
        extraction was requested (the final "done" message is only emitted
        when at least one of them ran).

    Raises:
        trio.TooSlowError: Subgraph generation exceeded its deadline.
        AssertionError: ``merge_subgraph`` returned ``None``.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]

    chunks = [
        d["content_with_weight"]
        for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True)
    ]

    # Read the graphrag section once and tolerate its absence everywhere.
    # Previously the extractor choice accepted a missing section but the
    # entity_types lookup indexed it directly and raised KeyError.
    graphrag_conf = row["kb_parser_config"].get("graphrag", {})
    kg_extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt

    # 10 minutes per chunk (at least 2 minutes) when timeout assertions are
    # enabled through the environment; effectively unbounded otherwise.
    deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
    with trio.fail_after(deadline):
        subgraph = await generate_subgraph(
            kg_extractor,
            tenant_id,
            kb_id,
            doc_id,
            chunks,
            language,
            graphrag_conf.get("entity_types", []),
            chat_model,
            embedding_model,
            callback,
        )
    if not subgraph:
        # generate_subgraph returns None when the graph already contains the doc.
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        # Snapshot the node names before merging; resolution is scoped to them.
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        if with_resolution:
            # NOTE(review): the lock is already held at this point; presumably
            # re-acquiring refreshes its expiry before the long-running step.
            # Confirm against RedisDistributedLock.spin_acquire.
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
        if with_community:
            # NOTE(review): same re-acquire pattern as above.
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        graphrag_task_lock.release()

    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return
async def run_graphrag_for_kb(
    row: dict,
    doc_ids: list[str],
    language: str,
    kb_parser_config: dict,
    chat_model,
    embedding_model,
    callback,
    *,
    with_resolution: bool = True,
    with_community: bool = True,
    max_parallel_docs: int = 4,
) -> dict:
    """Build and merge knowledge subgraphs for many documents of one KB.

    Steps: collect the documents (all documents of the KB when ``doc_ids`` is
    empty), load and re-pack their chunks, generate one subgraph per document
    with bounded concurrency, merge the subgraphs one by one under a per-KB
    distributed lock, then optionally run entity resolution and community
    extraction on the last merged graph.

    Args:
        row: Task description; ``tenant_id`` and ``kb_id`` are read.
        doc_ids: Documents to process; duplicates are dropped, order is kept.
        language: Forwarded to the extractor.
        kb_parser_config: KB parser configuration; its optional ``graphrag``
            section provides ``method`` and ``entity_types``.
        chat_model: Forwarded to extraction, resolution and community steps.
        embedding_model: Forwarded to extraction, merge and post-merge steps.
        callback: Progress sink, invoked as ``callback(msg=...)``.
        with_resolution: Run ``resolve_entities`` after merging.
        with_community: Run ``extract_community`` after merging.
        max_parallel_docs: Upper bound on concurrently generated subgraphs;
            values below 1 are treated as 1.

    Returns:
        A summary dict with the keys ``ok_docs``, ``failed_docs``,
        ``total_docs``, ``total_chunks`` and ``seconds``. ``failed_docs``
        holds ``(doc_id, error)`` tuples.
        NOTE(review): on the "no chunks in any document" path ``failed_docs``
        is the plain list of doc ids instead; kept as is for compatibility.
        Documents that simply have no chunks (while others do) appear in
        neither ``ok_docs`` nor ``failed_docs``.
    """
    tenant_id, kb_id = row["tenant_id"], row["kb_id"]
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = trio.current_time()
    fields_for_chunks = ["content_with_weight", "doc_id"]

    if not doc_ids:
        logging.info(f"Fetching all docs for {kb_id}")
        docs, _ = DocumentService.get_by_kb_id(
            kb_id=kb_id,
            page_number=0,
            items_per_page=0,
            orderby="create_time",
            desc=False,
            keywords="",
            run_status=[],
            types=[],
            suffix=[],
        )
        doc_ids = [doc["id"] for doc in docs]

    # Order-preserving de-duplication.
    doc_ids = list(dict.fromkeys(doc_ids))
    if not doc_ids:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
        return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}

    def load_doc_chunks(doc_id: str) -> list[str]:
        """Read a document's chunks and greedily pack them into pieces under 1024 tokens."""
        from common.token_utils import num_tokens_from_string

        chunks = []
        current_chunk = ""
        for d in globals.retriever.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            fields=fields_for_chunks,
            sort_by_position=True,
        ):
            content = d["content_with_weight"]
            if num_tokens_from_string(current_chunk + content) < 1024:
                current_chunk += content
            else:
                # The piece is full: flush it and start a new one with this content.
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = content
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    all_doc_chunks: dict[str, list[str]] = {}
    total_chunks = 0
    for doc_id in doc_ids:
        chunks = load_doc_chunks(doc_id)
        all_doc_chunks[doc_id] = chunks
        total_chunks += len(chunks)

    if total_chunks == 0:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
        return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}

    # Extractor choice and entity types are the same for every document.
    graphrag_conf = kb_parser_config.get("graphrag", {})
    kg_extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt
    entity_types = graphrag_conf.get("entity_types", [])

    # A semaphore created with 0 would block every worker forever, and a
    # negative value is rejected by trio; clamp to at least one slot.
    semaphore = trio.Semaphore(max(1, max_parallel_docs))
    subgraphs: dict[str, object] = {}
    failed_docs: list[tuple[str, str]] = []  # (doc_id, error)

    async def build_one(doc_id: str):
        """Generate the subgraph of one document, recording success or failure."""
        chunks = all_doc_chunks.get(doc_id, [])
        if not chunks:
            callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
            return

        # 10 minutes per chunk (at least 2 minutes) when timeout assertions
        # are enabled through the environment; effectively unbounded otherwise.
        deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000

        async with semaphore:
            try:
                msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
                callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
                with trio.fail_after(deadline):
                    sg = await generate_subgraph(
                        kg_extractor,
                        tenant_id,
                        kb_id,
                        doc_id,
                        chunks,
                        language,
                        entity_types,
                        chat_model,
                        embedding_model,
                        callback,
                    )
                if sg:
                    subgraphs[doc_id] = sg
                    callback(msg=f"{msg} done")
                else:
                    failed_docs.append((doc_id, "subgraph is empty"))
                    callback(msg=f"{msg} empty")
            except Exception as e:
                # One failing document must not abort the whole batch.
                failed_docs.append((doc_id, repr(e)))
                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")

    async with trio.open_nursery() as nursery:
        for doc_id in doc_ids:
            nursery.start_soon(build_one, doc_id)

    ok_docs = [d for d in doc_ids if d in subgraphs]
    if not ok_docs:
        callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
        now = trio.current_time()
        return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")

    # Names of all nodes contributed by this batch; entity resolution is
    # scoped to them later on.
    union_nodes: set = set()
    final_graph = None
    try:
        for doc_id in ok_docs:
            sg = subgraphs[doc_id]
            union_nodes.update(set(sg.nodes()))

            new_graph = await merge_subgraph(
                tenant_id,
                kb_id,
                doc_id,
                sg,
                embedding_model,
                callback,
            )
            if new_graph is not None:
                # Each merge returns the full graph so far; keep the latest.
                final_graph = new_graph

        if final_graph is None:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
        else:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
    finally:
        kb_lock.release()

    if not with_resolution and not with_community:
        now = trio.current_time()
        callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    if final_graph is None:
        # Resolution and community extraction need a graph object
        # (extract_community reads graph.graph["source_id"]); passing None
        # would only fail with an opaque AttributeError.
        now = trio.current_time()
        callback(msg=f"[GraphRAG] kb:{kb_id} no merged graph available, skip resolution/community.")
        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")

    try:
        if with_resolution:
            await resolve_entities(
                final_graph,
                union_nodes,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )

        if with_community:
            await extract_community(
                final_graph,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        kb_lock.release()

    now = trio.current_time()
    callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
    return {
        "ok_docs": ok_docs,
        "failed_docs": failed_docs,  # [(doc_id, error), ...]
        "total_docs": len(doc_ids),
        "total_chunks": total_chunks,
        "seconds": now - start,
    }
async def generate_subgraph(
    extractor: type[Extractor],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Extract entities and relations from a document and persist them as a subgraph.

    Args:
        extractor: Extractor class; it is instantiated here with ``llm_bdl``,
            ``language`` and ``entity_types``.
        tenant_id: Tenant whose index (``search.index_name(tenant_id)``) is written.
        kb_id: Knowledge base the subgraph belongs to.
        doc_id: Source document; recorded as ``source_id`` on every node,
            edge and on the graph itself.
        chunks: Text pieces handed to the extractor.
        language: Forwarded to the extractor.
        entity_types: Forwarded to the extractor.
        llm_bdl: Chat model handle passed to the extractor.
        embed_bdl: Not used by this function.
        callback: Progress sink, invoked as ``callback(msg=...)``.

    Returns:
        The generated ``nx.Graph``, or ``None`` when ``does_graph_contains``
        reports that the document is already part of the graph.

    Raises:
        AssertionError: An extracted entity or relation has no ``description``.
    """
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}")
        return None

    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
    )
    ents, rels = await ext(doc_id, chunks, callback)

    subgraph = nx.Graph()
    for ent in ents:
        assert "description" in ent, f"entity {ent} does not have description"
        ent["source_id"] = [doc_id]
        subgraph.add_node(ent["entity_name"], **ent)

    ignored_rels = 0
    for rel in rels:
        assert "description" in rel, f"relation {rel} does not have description"
        # Drop relations whose endpoints were not extracted as entities.
        if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
            ignored_rels += 1
            continue
        rel["source_id"] = [doc_id]
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            **rel,
        )
    if ignored_rels:
        callback(msg=f"ignored {ignored_rels} relations due to missing entities.")

    tidy_graph(subgraph, callback, check_attribute=False)
    subgraph.graph["source_id"] = [doc_id]

    # The subgraph is stored as one JSON chunk tagged "subgraph".
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)

    # Replace any subgraph previously stored for this document.
    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
    insert_errors = await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
    if insert_errors:
        # extract_community treats a truthy insert result as an error report;
        # the same is assumed here. Only log so existing callers keep their
        # control flow, but do not lose the failure silently.
        logging.warning(f"Insert subgraph chunk for doc {doc_id} reported: {insert_errors}")

    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph
@timeout(60 * 3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
    """Merge a document subgraph into the stored KB graph and persist the result.

    The stored graph is fetched with ``get_graph``. When there is none, the
    subgraph itself becomes the graph and all of its nodes and edges are
    recorded as added/updated; otherwise the stored graph is tidied and merged
    with the subgraph through ``graph_merge``. PageRank is then recomputed on
    the result and written to each node before ``set_graph`` stores it.

    Args:
        tenant_id: Tenant owning the graph.
        kb_id: Knowledge base owning the graph.
        doc_id: Source document; only used in the progress message.
        subgraph: Graph to merge; its ``graph["source_id"]`` is passed to ``get_graph``.
        embedding_model: Forwarded to ``set_graph``.
        callback: Progress sink, invoked as ``callback(msg=...)``.

    Returns:
        The merged graph.
    """
    began = trio.current_time()
    change = GraphChange()

    existing = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
    if existing is None:
        # First graph for this KB: everything in the subgraph counts as new.
        merged = subgraph
        change.added_updated_nodes = set(merged.nodes())
        change.added_updated_edges = set(merged.edges())
    else:
        logging.info("Merge with an exiting graph...................")
        tidy_graph(existing, callback)
        merged = graph_merge(existing, subgraph, change)

    # Store the freshly computed PageRank score on every node.
    nx.set_node_attributes(merged, nx.pagerank(merged), "pagerank")

    await set_graph(tenant_id, kb_id, embedding_model, merged, change, callback)

    elapsed = trio.current_time() - began
    callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {elapsed:.2f} seconds.")
    return merged
@timeout(60 * 30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Run entity resolution on the graph and persist the resolved graph.

    Args:
        graph: Graph handed to ``EntityResolution``.
        subgraph_nodes: Node names passed to the resolver alongside the graph.
        tenant_id: Tenant owning the graph.
        kb_id: Knowledge base owning the graph.
        doc_id: Not used by this function.
        llm_bdl: Chat model handle used to build the resolver.
        embed_bdl: Embedding model handle forwarded to ``set_graph``.
        callback: Progress sink, invoked as ``callback(msg=...)``.
    """
    began = trio.current_time()

    resolver = EntityResolution(llm_bdl)
    outcome = await resolver(graph, subgraph_nodes, callback=callback)
    resolved_graph, change = outcome.graph, outcome.change

    removed_node_count = len(change.removed_nodes)
    removed_edge_count = len(change.removed_edges)
    callback(msg=f"Graph resolution removed {removed_node_count} nodes and {removed_edge_count} edges.")
    callback(msg="Graph resolution updated pagerank.")

    await set_graph(tenant_id, kb_id, embed_bdl, resolved_graph, change, callback)

    elapsed = trio.current_time() - began
    callback(msg=f"Graph resolution done in {elapsed:.2f}s.")
@timeout(60 * 30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    """Generate community reports for the graph and index them in the doc store.

    The reports produced by ``CommunityReportsExtractor`` are turned into
    chunks tagged ``"community_report"``. Existing community-report chunks of
    the KB are deleted first, then the new ones are inserted in small batches.

    Args:
        graph: Graph to analyse; ``graph.graph["source_id"]`` is copied onto every chunk.
        tenant_id: Tenant whose index (``search.index_name(tenant_id)``) is written.
        kb_id: Knowledge base the reports belong to.
        doc_id: Not used by this function.
        llm_bdl: Chat model handle used to build the extractor.
        embed_bdl: Not used by this function.
        callback: Progress sink, invoked as ``callback(msg=...)``.

    Returns:
        A ``(community_structure, community_reports)`` tuple taken from the
        extractor result.

    Raises:
        Exception: The doc store returned a truthy result for an insert.
    """
    began = trio.current_time()
    reporter = CommunityReportsExtractor(llm_bdl)
    cr = await reporter(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    source_ids = graph.graph["source_id"]

    checkpoint = trio.current_time()
    callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {checkpoint - began:.2f}s.")
    began = checkpoint

    def build_chunk(structure, report):
        """Turn one community (structure + report text) into an indexable chunk."""
        evidences = "\n".join([finding.get("explanation", "") for finding in structure["findings"]])
        payload = {
            "report": report,
            "evidences": evidences,
        }
        record = {
            "id": get_uuid(),
            "docnm_kwd": structure["title"],
            "title_tks": rag_tokenizer.tokenize(structure["title"]),
            "content_with_weight": json.dumps(payload, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(payload["report"] + " " + payload["evidences"]),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": structure["weight"],
            "entities_kwd": structure["entities"],
            "important_kwd": structure["entities"],
            "kb_id": kb_id,
            "source_id": list(source_ids),
            "available_int": 0,
        }
        record["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(record["content_ltks"])
        return record

    chunks = [build_chunk(structure, report) for structure, report in zip(community_structure, community_reports)]

    # Drop the KB's previous community reports before writing the new ones.
    await trio.to_thread.run_sync(
        globals.docStoreConn.delete,
        {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
        search.index_name(tenant_id),
        kb_id,
    )

    batch_size = 4
    for offset in range(0, len(chunks), batch_size):
        batch = chunks[offset : offset + batch_size]
        insert_result = await trio.to_thread.run_sync(globals.docStoreConn.insert, batch, search.index_name(tenant_id), kb_id)
        if insert_result:
            # A truthy result is treated as an error report from the store.
            raise Exception(f"Insert chunk error: {insert_result}, please check log file and Elasticsearch/Infinity status!")

    finished = trio.current_time()
    callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {finished - began:.2f}s.")
    return community_structure, community_reports