Optimize graphrag again (#6513)

### What problem does this PR solve?

Removed `set_entity` and `set_relation` so that the doc engine is no longer accessed during graph computation.
Introduced `GraphChange` so that unchanged chunks are not rewritten.
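To illustrate the second point, a minimal, hypothetical sketch of the "write only what changed" idea (the names below are illustrative, not code from this PR): the change object records which entities were removed or touched, and only those keys are deleted or upserted.

```python
# Hypothetical sketch: persist only the delta recorded by a change object.
from dataclasses import dataclass, field

@dataclass
class Change:
    removed_nodes: set = field(default_factory=set)
    added_updated_nodes: set = field(default_factory=set)

def apply_change(index: dict, fresh: dict, change: Change) -> None:
    for name in change.removed_nodes:
        index.pop(name, None)                 # delete removed entities
    for name in change.added_updated_nodes:
        index[name] = fresh[name]             # upsert only touched entities

index = {"A": {"description": "old"}, "B": {"description": "unchanged"}}
apply_change(index, {"A": {"description": "new"}}, Change(added_updated_nodes={"A"}))
assert index == {"A": {"description": "new"}, "B": {"description": "unchanged"}}
```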

### Type of change

- [x] Performance Improvement
Author: Zhichang Yu
Date: 2025-03-26 15:34:42 +08:00
Committed by: GitHub
Parent: 7a677cb095
Commit: 6bf26e2a81
19 changed files with 466 additions and 530 deletions


@@ -12,26 +12,37 @@ import logging
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable
import os
import trio
from typing import Set, Tuple
import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph
import dataclasses
from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
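As a usage note, `chat_limiter` is a `trio.CapacityLimiter`, so callers bound LLM concurrency with an `async with` block; a small sketch (the chat call itself is hypothetical):

```python
# Hypothetical caller: at most MAX_CONCURRENT_CHATS coroutines enter this block at once.
async def limited_chat(llm, prompt):
    async with chat_limiter:
        return await llm.chat(prompt)  # hypothetical async chat API
```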
@dataclasses.dataclass
class GraphChange:
removed_nodes: Set[str] = dataclasses.field(default_factory=set)
added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
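A short illustrative sketch of how the four sets are meant to be consumed (the driver code is an assumption, but the field semantics follow this file): the removed sets translate into doc-store deletions, the added/updated sets into re-embedding and upserts of just those chunks.

```python
# Illustrative only: how a caller might inspect a GraphChange before persisting it.
change = GraphChange()
change.removed_nodes.add("OBSOLETE ENTITY")
change.added_updated_nodes.update({"ENTITY A", "ENTITY B"})
change.added_updated_edges.add(("ENTITY A", "ENTITY B"))

entities_to_delete = sorted(change.removed_nodes)        # -> delete by entity_kwd
entities_to_upsert = sorted(change.added_updated_nodes)  # -> rebuild and re-embed these chunks only
print(entities_to_delete, entities_to_upsert)
```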
def perform_variable_replacements(
input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
@@ -146,24 +157,74 @@ def set_tags_to_cache(kb_ids, tags):
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback):
    """
    Ensure all nodes and edges in the graph have some essential attribute.
    """
    def is_valid_node(node_attrs: dict) -> bool:
        valid_node = True
        for attr in ["description", "source_id"]:
            if attr not in node_attrs:
                valid_node = False
                break
        return valid_node
    purged_nodes = []
    for node, node_attrs in graph.nodes(data=True):
        if not is_valid_node(node_attrs):
            purged_nodes.append(node)
    for node in purged_nodes:
        graph.remove_node(node)
    if purged_nodes and callback:
        callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if not is_valid_node(attr):
            purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
def graph_merge(g1, g2):
    g = g2.copy()
    for n, attr in g1.nodes(data=True):
        if n not in g2.nodes():
            g.add_node(n, **attr)
    for source, target, attr in g1.edges(data=True):
        if g.has_edge(source, target):
            g[source][target].update({"weight": attr.get("weight", 0)+1})
        g.add_edge(source, target)#, **attr)
    for node_degree in g.degree:
        g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    return g
def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]
    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # An edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]
    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1
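A toy example of the merge above, assuming the definitions in this file are importable: every node carries `description` and `source_id`, edges additionally carry `keywords` and `weight`, and the `GraphChange` ends up recording exactly what `g2` touched.

```python
import networkx as nx

g1 = nx.Graph(source_id=["doc1"])
g1.add_node("A", description="alpha", source_id=["c1"], entity_type="ORG")
g2 = nx.Graph(source_id=["doc2"])
g2.add_node("A", description="alpha again", source_id=["c2"], entity_type="ORG")
g2.add_node("B", description="beta", source_id=["c2"], entity_type="ORG")
g2.add_edge("A", "B", description="works with", keywords=["work"], weight=1, source_id=["c2"])

change = GraphChange()
merged = graph_merge(g1, g2, change)
assert merged is g1                                    # merged in place
assert change.added_updated_nodes == {"A", "B"}
assert change.added_updated_edges == {("A", "B")}
assert merged.nodes["A"]["source_id"] == ["c1", "c2"]
assert merged.graph["source_id"] == ["doc1", "doc2"]
```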
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
@@ -237,55 +298,10 @@ def is_float_regex(value):
def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return json.loads(bin)
def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
def get_entity(tenant_id, kb_id, ent_name):
cache = get_entity_cache(tenant_id, kb_id, ent_name)
if cache:
return cache
conds = {
"fields": ["content_with_weight"],
"entity_kwd": ent_name,
"size": 10000,
"knowledge_graph_kwd": ["entity"]
}
res = []
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
for id in es_res.ids:
try:
if isinstance(ent_name, str):
set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
continue
return res
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"important_kwd": [ent_name],
"title_tks": rag_tokenizer.tokenize(ent_name),
"entity_kwd": ent_name,
@@ -293,28 +309,19 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
"entity_type_kwd": meta["entity_type"],
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"source_id": list(set(meta["source_id"])),
"source_id": meta["source_id"],
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
else:
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([ent_name])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
@@ -344,40 +351,30 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
return res
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"from_entity_kwd": from_ent_name,
"to_entity_kwd": to_ent_name,
"knowledge_graph_kwd": "relation",
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"important_kwd": meta["keywords"],
"source_id": list(set(meta["source_id"])),
"source_id": meta["source_id"],
"weight_int": int(meta["weight"]),
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
chunk,
search.index_name(tenant_id), kb_id)
else:
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity relation: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
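For orientation, a hedged sketch of how the two builders above can be driven sequentially (`set_graph` below runs them concurrently under a trio nursery); the embedding model is assumed to expose `llm_name` and `encode`, as the code above already requires.

```python
# Illustrative sequential driver (not code from this PR).
async def build_chunks(kb_id, embd_mdl, graph, change):
    chunks = []
    for name in change.added_updated_nodes:
        await graph_node_to_chunk(kb_id, embd_mdl, name, graph.nodes[name], chunks)
    for u, v in change.added_updated_edges:
        await graph_edge_to_chunk(kb_id, embd_mdl, u, v, graph.edges[u, v], chunks)
    return chunks  # ready to bulk-insert into the doc store
```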
async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
@@ -418,33 +415,68 @@ async def get_graph(tenant_id, kb_id):
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
if res.total == 0:
return None, []
return None
for id in res.ids:
try:
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
res.field[id]["source_id"]
g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
if "source_id" not in g.graph:
g.graph["source_id"] = res.field[id]["source_id"]
return g
except Exception:
continue
result = await rebuild_graph(tenant_id, kb_id)
return result
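The graph chunk read back here is plain node-link JSON; a minimal, self-contained round trip of the serialization used above, with the same `edges="edges"` keyword the file uses (networkx >= 3.4).

```python
import json
import networkx as nx
from networkx.readwrite import json_graph

g = nx.Graph(source_id=["doc1"])
g.add_node("A", description="alpha", source_id=["c1"])
g.add_edge("A", "B", description="rel", keywords=[], weight=1, source_id=["c1"])

payload = json.dumps(nx.node_link_data(g, edges="edges"), ensure_ascii=False)
restored = json_graph.node_link_graph(json.loads(payload), edges="edges")
assert set(restored.nodes) == {"A", "B"}
assert restored.graph["source_id"] == ["doc1"]   # graph-level attributes survive the round trip
```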
async def set_graph(tenant_id, kb_id, graph, docids):
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                          indent=2),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": list(docids),
        "available_int": 0,
        "removed_kwd": "N"
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
    if res.ids:
        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
                                                                           search.index_name(tenant_id), kb_id))
    else:
        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
    start = trio.current_time()
    await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph"]}, search.index_name(tenant_id), kb_id))
    if change.removed_nodes:
        await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))
    if change.removed_edges:
        async with trio.open_nursery() as nursery:
            for from_node, to_node in change.removed_edges:
                nursery.start_soon(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id))
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
    start = now
    chunks = [{
        "id": get_uuid(),
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": graph.graph.get("source_id", []),
        "available_int": 0,
        "removed_kwd": "N"
    }]
    async with trio.open_nursery() as nursery:
        for node in change.added_updated_nodes:
            node_attrs = graph.nodes[node]
            nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
        for from_node, to_node in change.added_updated_edges:
            edge_attrs = graph.edges[from_node, to_node]
            nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now
    await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "entity", "relation"]}, search.index_name(tenant_id), kb_id))
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
def is_continuous_subsequence(subseq, seq):
@@ -489,67 +521,6 @@ def merge_tuples(list1, list2):
return result
async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
def n_neighbor(id):
nonlocal graph, n_hop
count = 0
source_edge = list(graph.edges(id))
if not source_edge:
return []
count = count + 1
while count < n_hop:
count = count + 1
sc_edge = deepcopy(source_edge)
source_edge = []
for pair in sc_edge:
append_edge = list(graph.edges(pair[-1]))
for tuples in merge_tuples([pair], append_edge):
source_edge.append(tuples)
nbrs = []
for path in source_edge:
n = {"path": path, "weights": []}
wts = nx.get_edge_attributes(graph, 'weight')
for i in range(len(path)-1):
f, t = path[i], path[i+1]
n["weights"].append(wts.get((f, t), 0))
nbrs.append(n)
return nbrs
pr = nx.pagerank(graph)
try:
async with trio.open_nursery() as nursery:
for n, p in pr.items():
graph.nodes[n]["pagerank"] = p
nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
{"rank_flt": p,
"n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id)))
except Exception as e:
logging.exception(e)
ty2ents = defaultdict(list)
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
ty = graph.nodes[p].get("entity_type")
if not ty or len(ty2ents[ty]) > 12:
continue
ty2ents[ty].append(p)
chunk = {
"content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
"kb_id": kb_id,
"knowledge_graph_kwd": "ty2ents",
"available_int": 0
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id]))
if res.ids:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
chunk,
search.index_name(tenant_id), kb_id))
else:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
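For context on the removed helper above, a small standalone sketch of the two artifacts it produced: PageRank scores per node and a per-type list of top entities (the "ty2ents" chunk).

```python
from collections import defaultdict
import networkx as nx

g = nx.Graph()
g.add_node("A", entity_type="ORG")
g.add_node("B", entity_type="ORG")
g.add_node("C", entity_type="PERSON")
g.add_edges_from([("A", "B"), ("A", "C")])

pr = nx.pagerank(g)                                   # the per-entity rank_flt written above
ty2ents = defaultdict(list)
for name, _score in sorted(pr.items(), key=lambda x: x[1], reverse=True):
    ty = g.nodes[name].get("entity_type")
    if ty and len(ty2ents[ty]) <= 12:                 # keep only the top entities per type
        ty2ents[ty].append(name)
print(dict(ty2ents))                                  # e.g. {'ORG': ['A', 'B'], 'PERSON': ['C']}
```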
async def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
"size": 10000,
@@ -584,33 +555,46 @@ def flat_uniq_list(arr, key):
async def rebuild_graph(tenant_id, kb_id):
graph = nx.Graph()
src_ids = []
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
src_ids = set()
flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 39*bs, bs):
for i in range(0, 1024*bs, bs):
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]},
[],
OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id]
))
tot = settings.docStoreConn.getTotal(es_res)
if tot == 0:
return None, None
break
es_res = settings.docStoreConn.getFields(es_res, flds)
for id, d in es_res.items():
src_ids.extend(d.get("source_id", []))
if d["knowledge_graph_kwd"] == "entity":
graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
elif "from_entity_kwd" in d and "to_entity_kwd" in d:
graph.add_edge(
d["from_entity_kwd"],
d["to_entity_kwd"],
weight=int(d["weight_int"])
)
assert d["knowledge_graph_kwd"] == "relation"
src_ids.update(d.get("source_id", []))
attrs = json.loads(d["content_with_weight"])
graph.add_node(d["entity_kwd"], **attrs)
if len(es_res.keys()) < 128:
return graph, list(set(src_ids))
for i in range(0, 1024*bs, bs):
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]},
[],
OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id]
))
tot = settings.docStoreConn.getTotal(es_res)
if tot == 0:
return None
return graph, list(set(src_ids))
es_res = settings.docStoreConn.getFields(es_res, flds)
for id, d in es_res.items():
assert d["knowledge_graph_kwd"] == "relation"
src_ids.update(d.get("source_id", []))
if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]):
attrs = json.loads(d["content_with_weight"])
graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs)
src_ids = sorted(src_ids)
graph.graph["source_id"] = src_ids
return graph
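`rebuild_graph` pages through the index in fixed-size batches until a short page signals the end; a generic sketch of that scan pattern, with a hypothetical `fetch_page` standing in for the doc-store search call.

```python
# Generic paged-scan sketch; fetch_page(offset, limit) is hypothetical, not a RAGFlow API.
def scan_all(fetch_page, batch_size=256, max_batches=1024):
    items = []
    for offset in range(0, max_batches * batch_size, batch_size):
        page = fetch_page(offset, batch_size)
        if not page:
            break
        items.extend(page)
        if len(page) < batch_size:   # short page: nothing left to fetch
            break
    return items
```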