Light GraphRAG (#4585)

### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-22 19:43:14 +08:00
committed by GitHub
parent 1a367664f1
commit dd0ebbea35
55 changed files with 5461 additions and 4000 deletions

View File

@ -19,6 +19,9 @@
import random
import sys
from api.utils.log_utils import initRootLogger
from graphrag.general.index import WithCommunity, WithResolution, Dealer
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@ -53,7 +56,7 @@ from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email, tag
email, tag
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
@ -78,7 +81,7 @@ FACTORY = {
ParserType.ONE.value: one,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email,
ParserType.KG.value: knowledge_graph,
ParserType.KG.value: naive,
ParserType.TAG.value: tag
}
@ -118,7 +121,8 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
if to_page > 0:
if msg:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if from_page < to_page:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if msg:
msg = datetime.now().strftime("%H:%M:%S") + " " + msg
d = {"progress_msg": msg}
@ -177,8 +181,7 @@ def collect():
logging.info(f"collect task {msg['id']} {state}")
return None
if msg.get("type", "") == "raptor":
task["task_type"] = "raptor"
task["task_type"] = msg.get("task_type", "")
return task
@ -382,11 +385,9 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"])
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@ -422,7 +423,24 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
return res, tk_count, vector_size
return res, tk_count
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
chunks=chunks,
language=language,
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
def do_handle_task(task):
@ -466,14 +484,17 @@ def do_handle_task(task):
logging.exception(error_message)
raise
vts, _ = embedding_model.encode(["ok"])
vector_size = len(vts[0])
init_kb(task, vector_size)
# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
@ -481,6 +502,55 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
elif task.get("task_type", "") == "graph_resolution":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
elif task.get("task_type", "") == "graph_community":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
else:
# Standard chunking methods
start_ts = timer()
@ -507,8 +577,6 @@ def do_handle_task(task):
logging.info(progress_message)
progress_callback(msg=progress_message)
# logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
init_kb(task, vector_size)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = ""