From 36b62e0fab54042654ed29cc39603172b370be72 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Thu, 27 Mar 2025 16:40:36 +0800 Subject: [PATCH] EntityResolution batch. Close #6570 (#6602) ### What problem does this PR solve? EntityResolution batch ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- graphrag/entity_resolution.py | 14 ++++++++++---- graphrag/general/index.py | 34 ++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 052298534..67d351a73 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -63,7 +63,10 @@ class EntityResolution(Extractor): self._resolution_result_delimiter_key = "resolution_result_delimiter" self._input_text_key = "input_text" - async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult: + async def __call__(self, graph: nx.Graph, + subgraph_nodes: set[str], + prompt_variables: dict[str, Any] | None = None, + callback: Callable | None = None) -> EntityResolutionResult: """Call method definition.""" if prompt_variables is None: prompt_variables = {} @@ -88,16 +91,19 @@ class EntityResolution(Extractor): candidate_resolution = {entity_type: [] for entity_type in entity_types} for k, v in node_clusters.items(): - candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)] + candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)] num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()]) callback(msg=f"Identified {num_candidates} candidate pairs") resolution_result = set() + resolution_batch_size = 100 async with trio.open_nursery() as nursery: for candidate_resolution_i in candidate_resolution.items(): if not candidate_resolution_i[1]: continue - nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result)) + for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): + candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] + nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result)) callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") change = GraphChange() @@ -118,7 +124,7 @@ class EntityResolution(Extractor): change=change, ) - async def _resolve_candidate(self, candidate_resolution_i, resolution_result): + async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]): gen_conf = {"temperature": 0.5} pair_txt = [ f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 2b4a66db5..79b30a058 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -69,26 +69,27 @@ async def run_graphrag( embedding_model, callback, ) - new_graph = None - if subgraph: - new_graph = await merge_subgraph( - tenant_id, - kb_id, - doc_id, - subgraph, - embedding_model, - callback, - ) + if not subgraph: + return + + subgraph_nodes = set(subgraph.nodes()) + new_graph = await merge_subgraph( + tenant_id, + kb_id, + doc_id, + subgraph, + embedding_model, + callback, + ) + assert new_graph is not None if not with_resolution or not with_community: return - if new_graph is None: - new_graph = await get_graph(tenant_id, kb_id) - - if with_resolution and new_graph is not None: + if with_resolution: await resolve_entities( new_graph, + subgraph_nodes, tenant_id, kb_id, doc_id, @@ -96,7 +97,7 @@ async def run_graphrag( embedding_model, callback, ) - if with_community and new_graph is not None: + if with_community: await extract_community( new_graph, tenant_id, @@ -223,6 +224,7 @@ async def merge_subgraph( async def resolve_entities( graph, + subgraph_nodes: set[str], tenant_id: str, kb_id: str, doc_id: str, @@ -241,7 +243,7 @@ async def resolve_entities( er = EntityResolution( llm_bdl, ) - reso = await er(graph, callback=callback) + reso = await er(graph, subgraph_nodes, callback=callback) graph = reso.graph change = reso.change callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")