mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Light GraphRAG (#4585)
### What problem does this PR solve? #4543 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -16,18 +16,18 @@
|
||||
import logging
|
||||
import itertools
|
||||
import re
|
||||
import traceback
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.extractor import Extractor
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.nlp import is_english
|
||||
import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import perform_variable_replacements
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@ -37,8 +37,8 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
|
||||
@dataclass
|
||||
class EntityResolutionResult:
|
||||
"""Entity resolution result class definition."""
|
||||
|
||||
output: nx.Graph
|
||||
graph: nx.Graph
|
||||
removed_entities: list
|
||||
|
||||
|
||||
class EntityResolution(Extractor):
|
||||
@ -46,7 +46,6 @@ class EntityResolution(Extractor):
|
||||
|
||||
_resolution_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_on_error: ErrorHandlerFn
|
||||
_record_delimiter_key: str
|
||||
_entity_index_delimiter_key: str
|
||||
_resolution_result_delimiter_key: str
|
||||
@ -54,21 +53,19 @@ class EntityResolution(Extractor):
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
resolution_prompt: str | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
entity_index_delimiter_key: str | None = None,
|
||||
resolution_result_delimiter_key: str | None = None,
|
||||
input_text_key: str | None = None
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None
|
||||
):
|
||||
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
|
||||
self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
|
||||
self._record_delimiter_key = "record_delimiter"
|
||||
self._entity_index_dilimiter_key = "entity_index_delimiter"
|
||||
self._resolution_result_delimiter_key = "resolution_result_delimiter"
|
||||
self._input_text_key = "input_text"
|
||||
|
||||
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
@ -87,11 +84,11 @@ class EntityResolution(Extractor):
|
||||
}
|
||||
|
||||
nodes = graph.nodes
|
||||
entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
|
||||
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
|
||||
node_clusters = {entity_type: [] for entity_type in entity_types}
|
||||
|
||||
for node in nodes:
|
||||
node_clusters[graph.nodes[node]['entity_type']].append(node)
|
||||
node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
|
||||
|
||||
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
||||
for k, v in node_clusters.items():
|
||||
@ -128,44 +125,51 @@ class EntityResolution(Extractor):
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("error entity resolution")
|
||||
self._on_error(e, traceback.format_exc(), None)
|
||||
|
||||
connect_graph = nx.Graph()
|
||||
removed_entities = []
|
||||
connect_graph.add_edges_from(resolution_result)
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
|
||||
for remove_node in remove_nodes:
|
||||
removed_entities.append(remove_node)
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
|
||||
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
|
||||
remove_node_neighbors = list(remove_node_neighbors)
|
||||
for remove_node_neighbor in remove_node_neighbors:
|
||||
rel = self._get_relation_(remove_node, remove_node_neighbor)
|
||||
if graph.has_edge(remove_node, remove_node_neighbor):
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
if remove_node_neighbor == keep_node:
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
if graph.has_edge(keep_node, remove_node):
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
continue
|
||||
if not rel:
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
|
||||
'weight']
|
||||
graph[keep_node][remove_node_neighbor]['description'] += \
|
||||
graph[remove_node][remove_node_neighbor]['description']
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
self._merge_edges(keep_node, remove_node_neighbor, [rel])
|
||||
else:
|
||||
graph.add_edge(keep_node, remove_node_neighbor,
|
||||
weight=graph[remove_node][remove_node_neighbor]['weight'],
|
||||
description=graph[remove_node][remove_node_neighbor]['description'],
|
||||
source_id="")
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
pair = sorted([keep_node, remove_node_neighbor])
|
||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||
self._set_relation_(pair[0], pair[1],
|
||||
dict(
|
||||
src_id=pair[0],
|
||||
tgt_id=pair[1],
|
||||
weight=rel['weight'],
|
||||
description=rel['description'],
|
||||
keywords=[],
|
||||
source_id=rel.get("source_id", ""),
|
||||
metadata={"created_at": time.time()}
|
||||
))
|
||||
graph.remove_node(remove_node)
|
||||
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
|
||||
return EntityResolutionResult(
|
||||
output=graph,
|
||||
graph=graph,
|
||||
removed_entities=removed_entities
|
||||
)
|
||||
|
||||
def _process_results(
|
||||
|
||||
Reference in New Issue
Block a user