Optimize graphrag again (#6513)

### What problem does this PR solve?

Removed set_entity and set_relation to avoid accessing doc engine during
graph computation.
Introduced GraphChange to avoid writing unchanged chunks.

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu
2025-03-26 15:34:42 +08:00
committed by GitHub
parent 7a677cb095
commit 6bf26e2a81
19 changed files with 466 additions and 530 deletions

View File

@ -16,7 +16,6 @@
import logging
import itertools
import re
import time
from dataclasses import dataclass
from typing import Any, Callable
@ -28,7 +27,7 @@ from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, chat_limiter
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -39,7 +38,7 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
class EntityResolutionResult:
"""Entity resolution result class definition."""
graph: nx.Graph
removed_entities: list
change: GraphChange
class EntityResolution(Extractor):
@ -54,12 +53,8 @@ class EntityResolution(Extractor):
def __init__(
self,
llm_invoker: CompletionLLM,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None
):
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
super().__init__(llm_invoker)
"""Init method definition."""
self._llm = llm_invoker
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
@ -84,8 +79,8 @@ class EntityResolution(Extractor):
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
}
nodes = graph.nodes
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
nodes = sorted(graph.nodes())
entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}
for node in nodes:
@ -105,54 +100,22 @@ class EntityResolution(Extractor):
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
change = GraphChange()
connect_graph = nx.Graph()
removed_entities = []
connect_graph.add_edges_from(resolution_result)
all_entities_data = []
all_relationships_data = []
all_remove_nodes = []
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
all_remove_nodes.append(remove_nodes)
nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
else:
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
merging_nodes = list(sub_connect_graph.nodes)
nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
# Update pagerank
pr = nx.pagerank(graph)
for node_name, pagerank in pr.items():
graph.nodes[node_name]["pagerank"] = pagerank
return EntityResolutionResult(
graph=graph,
removed_entities=removed_entities
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):