Optimized graphrag again (#5927)

### What problem does this PR solve?

Optimized graphrag again

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu
2025-03-11 18:36:10 +08:00
committed by GitHub
parent 45318e7575
commit 939e668096
4 changed files with 117 additions and 101 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import time
@ -67,7 +68,7 @@ class EntityResolution(Extractor):
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
@ -93,6 +94,8 @@ class EntityResolution(Extractor):
candidate_resolution = {entity_type: [] for entity_type in entity_types}
for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
callback(msg=f"Identified {num_candidates} candidate pairs")
resolution_result = set()
async with trio.open_nursery() as nursery:
@ -100,48 +103,52 @@ class EntityResolution(Extractor):
if not candidate_resolution_i[1]:
continue
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
connect_graph = nx.Graph()
removed_entities = []
connect_graph.add_edges_from(resolution_result)
all_entities_data = []
all_relationships_data = []
all_remove_nodes = []
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
else:
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
all_remove_nodes.append(remove_nodes)
nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
else:
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
return EntityResolutionResult(
graph=graph,
@ -164,8 +171,10 @@ class EntityResolution(Extractor):
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),