Light GraphRAG (#4585)

### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-22 19:43:14 +08:00
committed by GitHub
parent 1a367664f1
commit dd0ebbea35
55 changed files with 5461 additions and 4000 deletions

View File

@ -16,18 +16,18 @@
import logging
import itertools
import re
import traceback
import time
from dataclasses import dataclass
from typing import Any
from typing import Any, Callable
import networkx as nx
from graphrag.extractor import Extractor
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import perform_variable_replacements
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -37,8 +37,8 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@dataclass
class EntityResolutionResult:
"""Entity resolution result class definition."""
output: nx.Graph
graph: nx.Graph
removed_entities: list
class EntityResolution(Extractor):
@ -46,7 +46,6 @@ class EntityResolution(Extractor):
_resolution_prompt: str
_output_formatter_prompt: str
_on_error: ErrorHandlerFn
_record_delimiter_key: str
_entity_index_delimiter_key: str
_resolution_result_delimiter_key: str
@ -54,21 +53,19 @@ class EntityResolution(Extractor):
def __init__(
self,
llm_invoker: CompletionLLM,
resolution_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
record_delimiter_key: str | None = None,
entity_index_delimiter_key: str | None = None,
resolution_result_delimiter_key: str | None = None,
input_text_key: str | None = None
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None
):
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
"""Init method definition."""
self._llm = llm_invoker
self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
self._input_text_key = input_text_key or "input_text"
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
self._record_delimiter_key = "record_delimiter"
self._entity_index_dilimiter_key = "entity_index_delimiter"
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
@ -87,11 +84,11 @@ class EntityResolution(Extractor):
}
nodes = graph.nodes
entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}
for node in nodes:
node_clusters[graph.nodes[node]['entity_type']].append(node)
node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
candidate_resolution = {entity_type: [] for entity_type in entity_types}
for k, v in node_clusters.items():
@ -128,44 +125,51 @@ class EntityResolution(Extractor):
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception as e:
except Exception:
logging.exception("error entity resolution")
self._on_error(e, traceback.format_exc(), None)
connect_graph = nx.Graph()
removed_entities = []
connect_graph.add_edges_from(resolution_result)
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
graph.remove_edge(keep_node, remove_node)
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
'weight']
graph[keep_node][remove_node_neighbor]['description'] += \
graph[remove_node][remove_node_neighbor]['description']
graph.remove_edge(remove_node, remove_node_neighbor)
self._merge_edges(keep_node, remove_node_neighbor, [rel])
else:
graph.add_edge(keep_node, remove_node_neighbor,
weight=graph[remove_node][remove_node_neighbor]['weight'],
description=graph[remove_node][remove_node_neighbor]['description'],
source_id="")
graph.remove_edge(remove_node, remove_node_neighbor)
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return EntityResolutionResult(
output=graph,
graph=graph,
removed_entities=removed_entities
)
def _process_results(