Refactor graphrag to remove redis lock (#5828)

### What problem does this PR solve?

Refactor GraphRAG to remove the Redis lock.

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu
2025-03-10 15:15:06 +08:00
committed by GitHub
parent 1163e9e409
commit 6ec6ca6971
9 changed files with 602 additions and 332 deletions

View File

@ -20,9 +20,7 @@ import random
import sys
from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import WithCommunity, WithResolution, Dealer
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.prompts import keyword_extraction, question_proposal, content_tagging
@ -45,6 +43,7 @@ import tracemalloc
import resource
import signal
import trio
import exceptiongroup
import numpy as np
from peewee import DoesNotExist
@ -453,24 +452,6 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
    """Run the graph-extraction ``Dealer`` over the stored chunks of one document.

    Args:
        row: Task record; this function reads ``doc_id``, ``tenant_id``,
            ``kb_id`` and ``parser_config["graphrag"]`` (``method`` and
            ``entity_types``) from it.
        chat_model: Chat model handed to the ``Dealer`` unchanged.
        language: Forwarded to the ``Dealer`` as ``language``.
        embedding_model: Forwarded to the ``Dealer`` as ``embed_bdl``.
        callback: Optional callable forwarded to the ``Dealer``
            (presumably a progress reporter; confirm against the caller).

    Returns:
        None. The work is done by awaiting the ``Dealer`` instance.
    """
    # Gather (doc_id, content) pairs for every chunk of this document.
    chunks = [
        (hit["doc_id"], hit["content_with_weight"])
        for hit in settings.retrievaler.chunk_list(
            row["doc_id"],
            row["tenant_id"],
            [str(row["kb_id"])],
            fields=["content_with_weight", "doc_id"],
        )
    ]

    # Only the exact method name 'general' selects the general extractor;
    # every other value falls back to the light extractor.
    if row["parser_config"]["graphrag"]["method"] == 'general':
        extractor_cls = GeneralKGExt
    else:
        extractor_cls = LightKGExt

    dealer = Dealer(
        extractor_cls,
        row["tenant_id"],
        str(row["kb_id"]),
        chat_model,
        chunks=chunks,
        language=language,
        entity_types=row["parser_config"]["graphrag"]["entity_types"],
        embed_bdl=embedding_model,
        callback=callback,
    )
    await dealer()
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
@ -526,24 +507,10 @@ async def do_handle_task(task):
return
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
if graphrag_conf.get("resolution", False):
start_ts = timer()
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
if graphrag_conf.get("community", False):
start_ts = timer()
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return
else:
# Standard chunking methods
@ -622,7 +589,11 @@ async def handle_task():
FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
err_msg = str(e)
while isinstance(e, exceptiongroup.ExceptionGroup):
e = e.exceptions[0]
err_msg += ' -- ' + str(e)
set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")