Refactor graphrag to remove redis lock (#5828)

### What problem does this PR solve?

Refactor graphrag to remove the Redis lock. `RedisDistributedLock` now delegates to `valkey.lock.Lock` instead of hand-rolling a `SETNX` polling loop, the RAPTOR clusterer drops its in-process `threading.Lock` by running blocking LLM/embedding calls through `trio.to_thread.run_sync`, and the resolution/community phases are folded into a single `run_graphrag` entry point in `graphrag.general.index`.

### Type of change

- [x] Refactoring
Zhichang Yu
2025-03-10 15:15:06 +08:00 (committed by GitHub)
parent 1163e9e409 · commit 6ec6ca6971
9 changed files with 602 additions and 332 deletions

View File

@@ -15,18 +15,25 @@
 #
 import logging
 import re
-from threading import Lock
 import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 import trio
-from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
+from graphrag.utils import (
+    get_llm_cache,
+    get_embed_cache,
+    set_embed_cache,
+    set_llm_cache,
+    chat_limiter,
+)
 from rag.utils import truncate
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
-    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
+    def __init__(
+        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+    ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
         self._embd_model = embd_model
@@ -34,22 +41,24 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token
 
-    def _chat(self, system, history, gen_conf):
+    async def _chat(self, system, history, gen_conf):
         response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
         if response:
             return response
-        response = self._llm_model.chat(system, history, gen_conf)
+        response = await trio.to_thread.run_sync(
+            lambda: self._llm_model.chat(system, history, gen_conf)
+        )
         response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
         return response
 
-    def _embedding_encode(self, txt):
+    async def _embedding_encode(self, txt):
         response = get_embed_cache(self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = self._embd_model.encode([txt])
+        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
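
Converting `_chat` and `_embedding_encode` into coroutines uses one pattern throughout: the blocking SDK call stays as-is but is pushed onto a worker thread, so the trio scheduler keeps running other tasks while the model round-trip is in flight. A minimal, self-contained sketch of that pattern (`slow_model_call` is an illustrative stand-in, not part of this PR):

```python
import time

import trio


def slow_model_call(prompt: str) -> str:
    # Stand-in for a blocking call such as llm_model.chat(...).
    time.sleep(1)
    return f"echo: {prompt}"


async def main():
    # run_sync suspends only this task while the call runs on a worker
    # thread; every other trio task keeps making progress meanwhile.
    result = await trio.to_thread.run_sync(lambda: slow_model_call("hi"))
    print(result)


trio.run(main)
```

The lambda is only needed to bind arguments; `trio.to_thread.run_sync(fn, arg)` works just as well for positional arguments.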
@@ -74,36 +83,48 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             return []
         chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
 
-        async def summarize(ck_idx, lock):
+        async def summarize(ck_idx: list[int]):
             nonlocal chunks
-            try:
-                texts = [chunks[i][0] for i in ck_idx]
-                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
-                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                async with chat_limiter:
-                    cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
-                            [{"role": "user",
-                              "content": self._prompt.format(cluster_content=cluster_content)}],
-                            {"temperature": 0.3, "max_tokens": self._max_token}
-                            ))
-                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
-                             cnt)
-                logging.debug(f"SUM: {cnt}")
-                embds, _ = self._embd_model.encode([cnt])
-                with lock:
-                    chunks.append((cnt, self._embedding_encode(cnt)))
-            except Exception as e:
-                logging.exception("summarize got exception")
-                return e
+            texts = [chunks[i][0] for i in ck_idx]
+            len_per_chunk = int(
+                (self._llm_model.max_length - self._max_token) / len(texts)
+            )
+            cluster_content = "\n".join(
+                [truncate(t, max(1, len_per_chunk)) for t in texts]
+            )
+            async with chat_limiter:
+                cnt = await self._chat(
+                    "You're a helpful assistant.",
+                    [
+                        {
+                            "role": "user",
+                            "content": self._prompt.format(
+                                cluster_content=cluster_content
+                            ),
+                        }
+                    ],
+                    {"temperature": 0.3, "max_tokens": self._max_token},
+                )
+            cnt = re.sub(
+                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                "",
+                cnt,
+            )
+            logging.debug(f"SUM: {cnt}")
+            embds = await self._embedding_encode(cnt)
+            chunks.append((cnt, embds))
 
         labels = []
-        lock = Lock()
         while end - start > 1:
-            embeddings = [embd for _, embd in chunks[start: end]]
+            embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
-                await summarize([start, start + 1], lock)
+                await summarize([start, start + 1])
                 if callback:
-                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                    callback(
+                        msg="Cluster one layer: {} -> {}".format(
+                            end - start, len(chunks) - end
+                        )
+                    )
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
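
This is also why the `lock` parameter and the `with lock:` around the append could be deleted rather than replaced: trio is cooperatively scheduled on a single thread, so code between checkpoints (`await` expressions) can never be interleaved with another task, and a plain `list.append` needs no mutex. A toy demonstration of that guarantee:

```python
import trio

results = []  # shared mutable state, deliberately unlocked


async def worker(n: int):
    await trio.sleep(0.01)  # checkpoint: other tasks may run here
    # From here to the end of the function there is no await, so no
    # other trio task can observe results in a half-updated state.
    results.append(n)


async def main():
    async with trio.open_nursery() as nursery:
        for n in range(5):
            nursery.start_soon(worker, n)
    print(sorted(results))  # [0, 1, 2, 3, 4], no lost updates


trio.run(main)
```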
@@ -112,7 +133,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             n_neighbors = int((len(embeddings) - 1) ** 0.8)
             reduced_embeddings = umap.UMAP(
-                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
+                n_neighbors=max(2, n_neighbors),
+                n_components=min(12, len(embeddings) - 2),
+                metric="cosine",
             ).fit_transform(embeddings)
             n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
             if n_clusters == 1:
@@ -127,18 +150,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             async with trio.open_nursery() as nursery:
                 for c in range(n_clusters):
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
-                    if not ck_idx:
-                        continue
+                    assert len(ck_idx) > 0
                     async with chat_limiter:
-                        nursery.start_soon(lambda: summarize(ck_idx, lock))
+                        nursery.start_soon(lambda: summarize(ck_idx))
 
-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
+                len(chunks) - end, n_clusters
+            )
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                callback(
+                    msg="Cluster one layer: {} -> {}".format(
+                        end - start, len(chunks) - end
+                    )
+                )
             start = end
             end = len(chunks)
 
         return chunks
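
Deleting the try/except-and-return-the-exception plumbing works because nursery semantics already cover it: if any `summarize` task raises, the nursery cancels its siblings and re-raises the error as an `ExceptionGroup` when the `async with` block exits. Note the PR acquires `chat_limiter` around `start_soon`, which throttles how fast tasks are spawned; the more common shape, sketched below under the assumption that `chat_limiter` is a `trio.CapacityLimiter`, holds the limiter inside the task for its whole duration:

```python
import trio

chat_limiter = trio.CapacityLimiter(2)  # assumed: at most 2 chats in flight


async def summarize_cluster(idx: int):
    async with chat_limiter:   # held for the duration of the work
        await trio.sleep(0.1)  # stand-in for the LLM round-trip
        print(f"summarized cluster {idx}")


async def main():
    async with trio.open_nursery() as nursery:
        for idx in range(5):
            nursery.start_soon(summarize_cluster, idx)
    # Exiting the nursery waits for all tasks; a failure in any one of
    # them cancels the rest and re-raises here as an ExceptionGroup.


trio.run(main)
```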

View File

@@ -20,9 +20,7 @@
 import random
 import sys
 from api.utils.log_utils import initRootLogger, get_project_base_directory
-from graphrag.general.index import WithCommunity, WithResolution, Dealer
-from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
-from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
+from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -45,6 +43,7 @@
 import tracemalloc
 import resource
 import signal
 import trio
+import exceptiongroup
 import numpy as np
 from peewee import DoesNotExist
@@ -453,24 +452,6 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
-async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
-    chunks = []
-    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
-                                             fields=["content_with_weight", "doc_id"]):
-        chunks.append((d["doc_id"], d["content_with_weight"]))
-
-    dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
-                    row["tenant_id"],
-                    str(row["kb_id"]),
-                    chat_model,
-                    chunks=chunks,
-                    language=language,
-                    entity_types=row["parser_config"]["graphrag"]["entity_types"],
-                    embed_bdl=embedding_model,
-                    callback=callback)
-    await dealer()
-
-
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
@@ -526,24 +507,10 @@ async def do_handle_task(task):
             return
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("resolution", False):
-            start_ts = timer()
-            with_res = WithResolution(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_res()
-            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("community", False):
-            start_ts = timer()
-            with_comm = WithCommunity(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_comm()
-            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
+        with_resolution = graphrag_conf.get("resolution", False)
+        with_community = graphrag_conf.get("community", False)
+        await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
+        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
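
With this change, resolution and community extraction are no longer orchestrated in the task executor; both travel as plain flags into the consolidated `run_graphrag` imported from `graphrag.general.index`. Judging only from the call site above, the new entry point has roughly this shape (a hypothetical skeleton for orientation; the real body lives in `graphrag/general/index.py` and is not shown in this diff):

```python
# Hypothetical skeleton inferred from the call site; not the PR's code.
async def run_graphrag(row, language, with_resolution, with_community,
                       chat_model, embedding_model, callback):
    # 1. collect the document's chunks and extract a subgraph from them
    # 2. merge the subgraph into the knowledge base's graph
    # 3. if with_resolution: deduplicate/merge equivalent entities
    # 4. if with_community: build community reports
    ...
```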
@@ -622,7 +589,11 @@ async def handle_task():
         FAILED_TASKS += 1
         CURRENT_TASKS.pop(task["id"], None)
         try:
-            set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
+            err_msg = str(e)
+            while isinstance(e, exceptiongroup.ExceptionGroup):
+                e = e.exceptions[0]
+                err_msg += ' -- ' + str(e)
+            set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
         except Exception:
             pass
         logging.exception(f"handle_task got exception for task {json.dumps(task)}")
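
Since the pipeline now runs inside trio nurseries, failures surface as (possibly nested) `ExceptionGroup`s, and a bare `str(e)` would often read just "unhandled errors in a TaskGroup" — hence the unwrapping loop, which follows the first sub-exception at each nesting level. The same walk in isolation (the sample exception is fabricated for illustration):

```python
import exceptiongroup  # backport package; re-exports the built-ins on 3.11+


def flatten_message(e: BaseException) -> str:
    err_msg = str(e)
    while isinstance(e, exceptiongroup.ExceptionGroup):
        e = e.exceptions[0]  # follow the first sub-exception down
        err_msg += ' -- ' + str(e)
    return err_msg


try:
    raise exceptiongroup.ExceptionGroup("task failed", [ValueError("bad chunk")])
except Exception as e:
    print(flatten_message(e))  # task failed (1 sub-exception) -- bad chunk
```

Only the first sub-exception at each level makes it into the progress message; sibling errors are still visible in the full traceback logged by `logging.exception` above.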

View File

@@ -16,13 +16,12 @@
 import logging
 import json
-import time
 import uuid
 
 import valkey as redis
 
 from rag import settings
 from rag.utils import singleton
+from valkey.lock import Lock
 
 class RedisMsg:
     def __init__(self, consumer, queue_name, group_name, msg_id, message):
@@ -281,29 +280,23 @@ REDIS_CONN = RedisDB()
 
 
 class RedisDistributedLock:
-    def __init__(self, lock_key, timeout=10):
+    def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
         self.lock_key = lock_key
-        self.lock_value = str(uuid.uuid4())
+        if lock_value:
+            self.lock_value = lock_value
+        else:
+            self.lock_value = str(uuid.uuid4())
         self.timeout = timeout
+        self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
 
     @staticmethod
     def clean_lock(lock_key):
         REDIS_CONN.REDIS.delete(lock_key)
 
-    def acquire_lock(self):
-        end_time = time.time() + self.timeout
-        while time.time() < end_time:
-            if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
-                return True
-            time.sleep(1)
-        return False
+    def acquire(self):
+        return self.lock.acquire()
 
-    def release_lock(self):
-        if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
-            REDIS_CONN.REDIS.delete(self.lock_key)
+    def release(self):
+        return self.lock.release()
 
     def __enter__(self):
-        self.acquire_lock()
+        self.acquire()
 
     def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.release_lock()
+        self.release()
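
Everything the hand-rolled version did by hand — token generation, atomic acquisition, only-if-owner release — now comes from `valkey.lock.Lock`, and `blocking_timeout=1` bounds how long `acquire()` waits before returning `False`, instead of polling `SETNX` once a second. A usage sketch (the key name and critical section are illustrative, and assume a reachable Redis/Valkey behind `REDIS_CONN`):

```python
import logging

# Serialize graph builds per knowledge base; key name is illustrative.
lock = RedisDistributedLock(f"graphrag_{kb_id}", timeout=60)

if lock.acquire():  # returns False once blocking_timeout expires
    try:
        build_knowledge_graph()  # hypothetical critical section
    finally:
        lock.release()  # valkey deletes the key only if we still own it
else:
    logging.info("another worker holds the lock; skipping")
```

One caveat worth noting: `__enter__` ignores the return value of `acquire()`, so the context-manager form silently proceeds even when the lock was not obtained; call sites that must not run concurrently should check `acquire()` explicitly as above.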