Light GraphRAG (#4585)
### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -3,16 +3,26 @@
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
- [LightRag](https://github.com/HKUDS/LightRAG)
|
||||
"""
|
||||
|
||||
import html
import json
import logging
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable

import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph

from api import settings
from rag.nlp import search, rag_tokenizer
from rag.utils.redis_conn import REDIS_CONN

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

@@ -131,3 +141,379 @@ def set_tags_to_cache(kb_ids, tags):

    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def graph_merge(g1, g2):
    """Merge graph g1 into a copy of g2; nodes present in both keep g2's attributes."""
    g = g2.copy()
    for n, attr in g1.nodes(data=True):
        if n not in g2.nodes():
            g.add_node(n, **attr)
            continue

    for source, target, attr in g1.edges(data=True):
        if g.has_edge(source, target):
            # Edge exists in both graphs: bump the weight instead of replacing attributes.
            g[source][target].update({"weight": attr.get("weight", 0) + 1})
            continue
        g.add_edge(source, target)  # , **attr)

    # Refresh each node's rank from its degree in the merged graph.
    for node_degree in g.degree:
        g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    return g


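# Illustrative usage (editor's sketch, not part of this commit): for nodes present
# in both graphs g2's attributes win, and a shared edge's weight becomes g1's
# weight + 1:
#   >>> g1 = nx.Graph(); g1.add_edge("A", "B", weight=2)
#   >>> g2 = nx.Graph(); g2.add_edge("A", "B", weight=5)
#   >>> graph_merge(g1, g2)["A"]["B"]["weight"]
#   3
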
def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()


def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name.upper(),
        entity_type=entity_type.upper(),
        description=entity_description,
        source_id=entity_source_id,
    )


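# Illustrative usage (editor's sketch, not part of this commit): the LLM emits
# delimited records whose attributes arrive here as a list; anything that is not a
# well-formed "entity" record yields None:
#   >>> rec = ['"entity"', '"Apple Inc."', '"organization"', '"A consumer electronics company."']
#   >>> ent = handle_single_entity_extraction(rec, chunk_key="chunk-0")
#   >>> sorted(ent.keys())
#   ['description', 'entity_name', 'entity_type', 'source_id']
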
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    pair = sorted([source.upper(), target.upper()])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


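# Illustrative usage (editor's sketch, not part of this commit): endpoints are
# sorted so (B, A) and (A, B) canonicalize to the same edge, and the trailing
# attribute becomes the weight when it parses as a float (else 1.0):
#   >>> rel = handle_single_relationship_extraction(
#   ...     ['"relationship"', '"B"', '"A"', '"B supplies A"', '"supply"', '8'],
#   ...     chunk_key="chunk-0")
#   >>> rel["weight"]
#   8.0
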
def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


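# Illustrative usage (editor's sketch, not part of this commit): positional strings
# alternate into user/assistant turns, e.g. for replaying a gleaning dialogue:
#   >>> pack_user_ass_to_openai_messages("hi", "hello")
#   [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'hello'}]
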
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


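# Illustrative usage (editor's sketch, not part of this commit): markers are
# regex-escaped, so delimiters such as "<|>" are matched literally:
#   >>> split_string_by_multi_markers("a<|>b##c", ["<|>", "##"])
#   ['a', 'b', 'c']
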
def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


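# Editor's note (not part of this commit): the pattern accepts signed decimals but
# not exponent notation:
#   >>> is_float_regex("3.14"), is_float_regex("-7"), is_float_regex("1e5")
#   (True, True, False)
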
def chunk_id(chunk):
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


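# Editor's note (not part of this commit): ids are deterministic in the chunk's
# content and kb_id, so re-writing the same chunk is idempotent:
#   >>> chunk_id({"content_with_weight": "{}", "kb_id": "kb1"}) == chunk_id(
#   ...     {"content_with_weight": "{}", "kb_id": "kb1"})
#   True
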
def get_entity(tenant_id, kb_id, ent_name):
    """Fetch entity metadata; a single name returns one dict, a list of names returns a list."""
    conds = {
        "fields": ["content_with_weight"],
        "entity_kwd": ent_name,
        "size": 10000,
        "knowledge_graph_kwd": ["entity"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
    for id in es_res.ids:
        try:
            if isinstance(ent_name, str):
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue

    return res


def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
    chunk = {
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": list(set(meta["source_id"])),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])
    if res.ids:
        # Upsert: an entity chunk with this name already exists, so update it in place.
        settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
    else:
        # New entity: embed its name (with a per-model cache) before inserting.
        ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
        if ebd is None:
            try:
                ebd, _ = embd_mdl.encode([ent_name])
                ebd = ebd[0]
                set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
            except Exception as e:
                logging.exception(f"Fail to embed entity: {e}")
        if ebd is not None:
            chunk["q_%d_vec" % len(ebd)] = ebd
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))


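# Illustrative usage (editor's sketch, not part of this commit): `meta` is expected
# to carry the fields produced by handle_single_entity_extraction, with source_id
# accumulated into a list across chunks, e.g.:
#   meta = {
#       "entity_name": "APPLE INC.",
#       "entity_type": "ORGANIZATION",
#       "description": "A consumer electronics company.",
#       "source_id": ["chunk-0", "chunk-7"],
#   }
#   set_entity(tenant_id, kb_id, embd_mdl, "APPLE INC.", meta)
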
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    """Fetch relation metadata between entities; size == 1 returns one dict, otherwise a list."""
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {
        "fields": ["content_with_weight"],
        "size": size,
        "from_entity_kwd": ents,
        "to_entity_kwd": ents,
        "knowledge_graph_kwd": ["relation"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res


def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
    chunk = {
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": list(set(meta["source_id"])),
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    res = settings.retrievaler.search({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])

    if res.ids:
        # Upsert: this relation already exists, so update the stored chunk.
        settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
                                     chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        # New relation: embed "src->tgt: description" (with a per-model cache) before inserting.
        txt = f"{from_ent_name}->{to_ent_name}"
        ebd = get_embed_cache(embd_mdl.llm_name, txt)
        if ebd is None:
            try:
                ebd, _ = embd_mdl.encode([txt + f": {meta['description']}"])
                ebd = ebd[0]
                set_embed_cache(embd_mdl.llm_name, txt, ebd)
            except Exception as e:
                logging.exception(f"Fail to embed entity relation: {e}")
        if ebd is not None:
            chunk["q_%d_vec" % len(ebd)] = ebd
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))


def get_graph(tenant_id, kb_id):
    conds = {
        "fields": ["content_with_weight", "source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
    for id in res.ids:
        try:
            return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
                res.field[id]["source_id"]
        except Exception:
            continue
    return None, None


def set_graph(tenant_id, kb_id, graph, docids):
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                          indent=2),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": list(docids),
        "available_int": 0,
        "removed_kwd": "N"
    }
    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
    if res.ids:
        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))


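# Illustrative usage (editor's sketch, not part of this commit): the whole graph is
# stored as a single JSON chunk in node-link form, which round-trips with
# get_graph's json_graph.node_link_graph(..., edges="edges"):
#   >>> g = nx.Graph(); g.add_edge("A", "B", weight=1)
#   >>> data = nx.node_link_data(g, edges="edges")
#   >>> json_graph.node_link_graph(data, edges="edges").has_edge("A", "B")
#   True
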
def is_continuous_subsequence(subseq, seq):
    """Return True if subseq[0] occurs in seq immediately followed by subseq[-1]."""

    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False


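# Illustrative usage (editor's sketch, not part of this commit): callers pass
# 2-element edge tuples, so this effectively asks whether the edge already occurs
# consecutively (in this orientation) within the path:
#   >>> is_continuous_subsequence(("B", "C"), ("A", "B", "C"))
#   True
#   >>> is_continuous_subsequence(("C", "B"), ("A", "B", "C"))
#   False
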
def merge_tuples(list1, list2):
    """Extend each path in list1 with edges from list2 that continue its last node."""
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            # The path already revisits this node; keep it as-is to avoid cycling.
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                matchh = (match[1], match[0])
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result


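# Illustrative usage (editor's sketch, not part of this commit): each path in list1
# is extended by edges in list2 leaving its last node, while edges already walked
# (in either orientation) are skipped, so the back-edge ("B", "A") is dropped here:
#   >>> merge_tuples([("A", "B")], [("B", "C"), ("B", "A")])
#   [('A', 'B', 'C')]
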
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
    def n_neighbor(id):
        nonlocal graph, n_hop
        count = 0
        source_edge = list(graph.edges(id))
        if not source_edge:
            return []
        count = count + 1
        while count < n_hop:
            count = count + 1
            sc_edge = deepcopy(source_edge)
            source_edge = []
            for pair in sc_edge:
                append_edge = list(graph.edges(pair[-1]))
                for tuples in merge_tuples([pair], append_edge):
                    source_edge.append(tuples)
        nbrs = []
        for path in source_edge:
            n = {"path": path, "weights": []}
            wts = nx.get_edge_attributes(graph, 'weight')
            for i in range(len(path) - 1):
                f, t = path[i], path[i + 1]
                n["weights"].append(wts.get((f, t), 0))
            nbrs.append(n)
        return nbrs

    # Store each node's PageRank and its n-hop neighbourhood paths on its entity chunk.
    pr = nx.pagerank(graph)
    for n, p in pr.items():
        graph.nodes[n]["pagerank"] = p
        try:
            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
                                         {"rank_flt": p,
                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
                                         search.index_name(tenant_id), kb_id)
        except Exception as e:
            logging.exception(e)

    # Collect the top-PageRank entities for each entity type (the "> 12" check caps each list at 13).
    ty2ents = defaultdict(list)
    for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
        ty = graph.nodes[p].get("entity_type")
        if not ty or len(ty2ents[ty]) > 12:
            continue
        ty2ents[ty].append(p)

    chunk = {
        "content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
        "kb_id": kb_id,
        "knowledge_graph_kwd": "ty2ents",
        "available_int": 0
    }
    res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])
    if res.ids:
        settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
                                     chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))


def get_entity_type2sampels(idxnms, kb_ids: list):
    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
                                          "size": 10000,
                                          "fields": ["content_with_weight"]},
                                         idxnms, kb_ids)

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))
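# Illustrative usage (editor's sketch, not part of this commit): pulls one field
# from a list of dicts, flattening list-valued entries and deduplicating (set()
# does not preserve order):
#   >>> sorted(flat_uniq_list([{"kw": ["a", "b"]}, {"kw": "b"}], "kw"))
#   ['a', 'b']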