mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add graphrag (#1793)
### What problem does this PR solve? #1594 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
160
graphrag/leiden.py
Normal file
160
graphrag/leiden.py
Normal file
@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, cast, List
|
||||
import html
|
||||
from graspologic.partition import hierarchical_leiden
|
||||
from graspologic.utils import largest_connected_component
|
||||
|
||||
import networkx as nx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Rebuild ``graph`` with its nodes and edges inserted in a deterministic order.

    A new ``nx.DiGraph`` (for directed input) or ``nx.Graph`` (otherwise) is
    created. Nodes are added sorted by node identifier, with their attribute
    data. Edges are added sorted by the string ``"<source> -> <target>"``,
    with their attribute data.

    Args:
        graph: The graph to rebuild. It is only read, never modified.

    Returns:
        The newly built graph.
    """
    directed = graph.is_directed()
    stable = nx.DiGraph() if directed else nx.Graph()

    # Nodes first, ordered by identifier, each carrying its attribute dict.
    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))

    edge_list = list(graph.edges(data=True))
    if not directed:
        # In an undirected graph "A - B" and "B - A" are the same edge, but
        # the endpoint order still leaks into iteration order, and some
        # consumers depend on that order. Put the smaller endpoint first so
        # the same relationship is always stored the same way.
        edge_list = [
            (right, left, attrs) if left > right else (left, right, attrs)
            for left, right, attrs in edge_list
        ]

    # list.sort is stable, like sorted(), so ties keep their relative order.
    edge_list.sort(key=lambda edge: f"{edge[0]} -> {edge[1]}")

    stable.add_edges_from(edge_list)
    return stable
|
||||
|
||||
|
||||
def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Relabel every node with a normalized form of its name.

    Each node name is upper-cased, stripped of surrounding whitespace and then
    passed through ``html.unescape``. The resulting old-name -> new-name
    mapping is handed to ``nx.relabel_nodes``, whose result is returned.

    Args:
        graph: Graph whose node identifiers are strings (``upper`` and
            ``strip`` are called on each of them).

    Returns:
        The graph returned by ``nx.relabel_nodes`` for that mapping.
    """
    # NOTE(review): names are upper-cased *before* unescaping, so HTML entity
    # names reach html.unescape in upper case -- confirm this is intended.
    renamed = {}
    for name in graph.nodes():
        renamed[name] = html.unescape(name.upper().strip())  # type: ignore
    return nx.relabel_nodes(graph, renamed)
|
||||
|
||||
|
||||
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Extract the largest connected component and return it in a stable form.

    The work is done on a copy of ``graph``. The copy is reduced with
    ``largest_connected_component``, its node names are normalized with
    ``normalize_node_names``, and the result is rebuilt by ``_stabilize_graph``
    so nodes and edges are inserted in a deterministic order.

    Args:
        graph: The input graph; a copy is taken before any processing.

    Returns:
        The graph produced by ``_stabilize_graph``.
    """
    working_copy = graph.copy()
    component = cast(nx.Graph, largest_connected_component(working_copy))
    return _stabilize_graph(normalize_node_names(component))
|
||||
|
||||
|
||||
def _compute_leiden_communities(
    graph: nx.Graph | nx.DiGraph,
    max_cluster_size: int,
    use_lcc: bool,
    seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
    """Cluster ``graph`` with hierarchical Leiden and group the result by level.

    Args:
        graph: The graph to cluster.
        max_cluster_size: Forwarded to ``hierarchical_leiden`` as
            ``max_cluster_size``.
        use_lcc: When true, clustering runs on
            ``stable_largest_connected_component(graph)`` instead of ``graph``.
        seed: Forwarded to ``hierarchical_leiden`` as ``random_seed``.

    Returns:
        A mapping ``{level: {node: cluster}}`` built from the ``level``,
        ``node`` and ``cluster`` attributes of every entry yielded by
        ``hierarchical_leiden``.
    """
    target = stable_largest_connected_component(graph) if use_lcc else graph

    assignments = hierarchical_leiden(
        target, max_cluster_size=max_cluster_size, random_seed=seed
    )

    by_level: dict[int, dict[str, int]] = {}
    for entry in assignments:
        by_level.setdefault(entry.level, {})[entry.node] = entry.cluster
    return by_level
|
||||
|
||||
|
||||
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
    """Detect Leiden communities and return them grouped by hierarchy level.

    Args:
        graph: Graph to cluster. Node attributes ``rank`` (default 0) and
            ``weight`` (default 1) are read to score each community.
        args: Options, all optional:
            ``max_cluster_size`` (default 12), ``use_lcc`` (default True),
            ``seed`` (default 0xDEADBEEF) -- forwarded to
            ``_compute_leiden_communities``;
            ``verbose`` (default False) -- log the parameters at INFO level;
            ``levels`` -- iterable of levels to report; when absent or None,
            every level found is reported in ascending order.

    Returns:
        ``{level: {community_id: {"weight": w, "nodes": [node_id, ...]}}}``
        where ``community_id`` is the cluster id converted to ``str`` and
        ``w`` is the sum of ``rank * weight`` over the community's nodes,
        divided by the largest such sum at that level. When the largest sum
        is 0 the weights are left at 0. An empty graph yields ``{}``.

    Raises:
        KeyError: If a requested level was not produced by the clustering, or
            a clustered node id is not a node of ``graph``.
    """
    max_cluster_size = args.get("max_cluster_size", 12)
    use_lcc = args.get("use_lcc", True)
    if args.get("verbose", False):
        log.info(
            "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
        )
    if not graph.nodes():
        return {}

    node_id_to_community_map = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=args.get("seed", 0xDEADBEEF),
    )

    # If the caller does not pass in levels, use them all.
    levels = args.get("levels")
    if levels is None:
        levels = sorted(node_id_to_community_map.keys())

    results_by_level: dict[int, dict[str, dict]] = {}
    for level in levels:
        result: dict[str, dict] = {}
        results_by_level[level] = result
        for node_id, raw_community_id in node_id_to_community_map[level].items():
            community_id = str(raw_community_id)
            if community_id not in result:
                result[community_id] = {"weight": 0, "nodes": []}
            community = result[community_id]
            community["nodes"].append(node_id)
            # NOTE(review): with use_lcc the clustered graph has had its node
            # names normalized, while this lookup uses the caller's original
            # graph -- confirm the names are already normalized upstream.
            node_attrs = graph.nodes[node_id]
            community["weight"] += node_attrs.get("rank", 0) * node_attrs.get("weight", 1)

        if not result:
            continue
        max_weight = max(community["weight"] for community in result.values())
        if max_weight == 0:
            # Every community scored 0 (e.g. no node carries a "rank"):
            # nothing to normalize, and dividing would raise ZeroDivisionError.
            continue
        for community in result.values():
            community["weight"] /= max_weight

    return results_by_level
|
||||
|
||||
|
||||
def add_community_levels2graph(graph: nx.Graph, commu_info: dict[str, dict[str, dict]]):
    """Record, on each node, the community it belongs to at every level.

    This function was previously also named ``add_community_info2graph`` and
    was therefore shadowed by the later definition of that name; it now has
    its own name so it can actually be called.

    Args:
        graph: Graph whose node attribute dicts are updated in place.
        commu_info: ``{level: {community_id: {"nodes": [node_id, ...]}}}``.
            Only the ``"nodes"`` entry of each community is read.

    Each listed node gets ``graph.nodes[n]["community"][level] = community_id``;
    the ``"community"`` dict is created on first use.
    """
    for level, cluster_info in commu_info.items():
        for community_id, community in cluster_info.items():
            for n in community["nodes"]:
                graph.nodes[n].setdefault("community", {})[level] = community_id


def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
    """Append ``community_title`` to the ``"communities"`` list of each node.

    Args:
        graph: Graph whose node attribute dicts are updated in place.
        nodes: Identifiers of the nodes to tag.
        community_title: Value appended to each node's ``"communities"``
            list; the list is created on first use.
    """
    for n in nodes:
        graph.nodes[n].setdefault("communities", []).append(community_title)
|
||||
Reference in New Issue
Block a user