Light GraphRAG (#4585)
### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -19,6 +19,9 @@
 import random
 import sys
 from api.utils.log_utils import initRootLogger
+from graphrag.general.index import WithCommunity, WithResolution, Dealer
+from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
+from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
```
```diff
@@ -53,7 +56,7 @@ from api import settings
 from api.versions import get_ragflow_version
 from api.db.db_models import close_connection
 from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
-    knowledge_graph, email, tag
+    email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
```
```diff
@@ -78,7 +81,7 @@ FACTORY = {
     ParserType.ONE.value: one,
     ParserType.AUDIO.value: audio,
     ParserType.EMAIL.value: email,
-    ParserType.KG.value: knowledge_graph,
+    ParserType.KG.value: naive,
     ParserType.TAG.value: tag
 }
 
```
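With the dedicated `knowledge_graph` parser removed, `ParserType.KG` now maps to the `naive` chunker: KG knowledge bases are chunked like ordinary documents, and graph construction moves into a separately queued `graphrag` task. A minimal sketch of the remapping, with a hypothetical placeholder standing in for the real `rag.app` modules:

```python
# Hypothetical placeholder for the rag.app chunker modules.
def naive_chunker(doc):
    return [f"chunk: {doc}"]

# Before this PR, "knowledge_graph" pointed at a dedicated parser that built
# the graph while chunking; now KG documents share the naive chunker and the
# graph is built later by the "graphrag" task type.
FACTORY = {
    "naive": naive_chunker,
    "knowledge_graph": naive_chunker,  # was: a dedicated knowledge_graph parser
}

print(FACTORY["knowledge_graph"]("sample.pdf"))  # -> ['chunk: sample.pdf']
```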
```diff
@@ -118,7 +121,8 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
 
     if to_page > 0:
         if msg:
-            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
+            if from_page < to_page:
+                msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     if msg:
         msg = datetime.now().strftime("%H:%M:%S") + " " + msg
     d = {"progress_msg": msg}
```
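The `set_progress` tweak only prefixes a page range when the task actually spans more than one page, so single-page tasks no longer log messages like `Page(5~5): ...`. A runnable sketch of the new rule (the standalone wrapper is hypothetical; the logic mirrors the diff):

```python
from datetime import datetime

def format_progress_msg(msg, from_page=0, to_page=-1):
    # New guard: only multi-page tasks (from_page < to_page) get a prefix.
    if to_page > 0 and msg and from_page < to_page:
        msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
    if msg:
        msg = datetime.now().strftime("%H:%M:%S") + " " + msg
    return msg

print(format_progress_msg("OCR started", from_page=4, to_page=4))  # no prefix
print(format_progress_msg("OCR started", from_page=0, to_page=9))  # Page(1~10): ...
```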
```diff
@@ -177,8 +181,7 @@ def collect():
         logging.info(f"collect task {msg['id']} {state}")
         return None
 
-    if msg.get("type", "") == "raptor":
-        task["task_type"] = "raptor"
+    task["task_type"] = msg.get("task_type", "")
     return task
 
 
```
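`collect()` stops special-casing `raptor` and copies whatever `task_type` the producer enqueued, which is what lets the new `graphrag`, `graph_resolution`, and `graph_community` tasks reach the worker at all. A pass-through sketch with illustrative payloads (the field sets here are hypothetical):

```python
def assign_task_type(task, msg):
    # Any task_type flows through untouched; "" means standard chunking.
    task["task_type"] = msg.get("task_type", "")
    return task

print(assign_task_type({}, {"id": "t1", "task_type": "raptor"}))
print(assign_task_type({}, {"id": "t2", "task_type": "graphrag"}))
print(assign_task_type({}, {"id": "t3"}))  # -> {'task_type': ''}
```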
```diff
@@ -382,11 +385,9 @@ def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
 
 
-def run_raptor(row, chat_mdl, embd_mdl, callback=None):
-    vts, _ = embd_mdl.encode(["ok"])
-    vector_size = len(vts[0])
-    vctr_nm = "q_%d_vec" % vector_size
+def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
+    vctr_nm = "q_%d_vec"%vector_size
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
                                              fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
```
```diff
@@ -422,7 +423,24 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
         res.append(d)
         tk_count += num_tokens_from_string(content)
-    return res, tk_count, vector_size
+    return res, tk_count
 
 
+def run_graphrag(row, chat_model, language, embedding_model, callback=None):
+    chunks = []
+    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
+                                             fields=["content_with_weight", "doc_id"]):
+        chunks.append((d["doc_id"], d["content_with_weight"]))
+
+    Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
+           row["tenant_id"],
+           str(row["kb_id"]),
+           chat_model,
+           chunks=chunks,
+           language=language,
+           entity_types=row["parser_config"]["graphrag"]["entity_types"],
+           embed_bdl=embedding_model,
+           callback=callback)
+
+
 def do_handle_task(task):
```
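`run_graphrag` gathers every chunk of the document and hands them to `Dealer`, choosing the light extractor unless the knowledge base's `parser_config` explicitly asks for `general`. A sketch of just the selection rule, against a hypothetical `parser_config` shaped like the keys the diff reads:

```python
# Hypothetical parser_config; only the keys read by run_graphrag are shown.
parser_config = {
    "graphrag": {
        "method": "light",  # anything but "general" selects the light extractor
        "entity_types": ["person", "organization", "location"],
    }
}

def pick_extractor(parser_config):
    # Same condition as the diff: LightKGExt unless method == "general".
    return ("LightKGExt"
            if parser_config["graphrag"]["method"] != "general"
            else "GeneralKGExt")

print(pick_extractor(parser_config))  # -> LightKGExt
```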
```diff
@@ -466,14 +484,17 @@ def do_handle_task(task):
         logging.exception(error_message)
         raise
 
+    vts, _ = embedding_model.encode(["ok"])
+    vector_size = len(vts[0])
+    init_kb(task, vector_size)
+
     # Either using RAPTOR or Standard chunking methods
     if task.get("task_type", "") == "raptor":
         try:
             # bind LLM for raptor
             chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-
             # run RAPTOR
-            chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
+            chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
         except TaskCanceledException:
             raise
         except Exception as e:
```
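The embedding-width probe that used to live inside `run_raptor` is hoisted into `do_handle_task`: one throwaway `encode(["ok"])` determines `vector_size`, which is then shared by `init_kb` and `run_raptor` instead of being recomputed per task type. A self-contained sketch with a fake embedder standing in for the `LLMBundle`:

```python
import numpy as np

class FakeEmbedder:
    """Hypothetical stand-in for the embedding LLMBundle; 1024 dims is arbitrary."""
    def encode(self, texts):
        # Real bundles return (vectors, token_count).
        return np.zeros((len(texts), 1024)), 0

def probe_vector_size(embd_mdl):
    # One cheap encode call reveals the embedding width.
    vts, _ = embd_mdl.encode(["ok"])
    return len(vts[0])

vector_size = probe_vector_size(FakeEmbedder())
print("q_%d_vec" % vector_size)  # vector field name in the index, e.g. q_1024_vec
```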
```diff
@@ -481,6 +502,55 @@ def do_handle_task(task):
             progress_callback(-1, msg=error_message)
             logging.exception(error_message)
             raise
+    # Either using graphrag or Standard chunking methods
+    elif task.get("task_type", "") == "graphrag":
+        start_ts = timer()
+        try:
+            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+            run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
+            progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
+        except TaskCanceledException:
+            raise
+        except Exception as e:
+            error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
+            progress_callback(-1, msg=error_message)
+            logging.exception(error_message)
+            raise
+        return
+    elif task.get("task_type", "") == "graph_resolution":
+        start_ts = timer()
+        try:
+            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+            WithResolution(
+                task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
+                progress_callback
+            )
+            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
+        except TaskCanceledException:
+            raise
+        except Exception as e:
+            error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
+            progress_callback(-1, msg=error_message)
+            logging.exception(error_message)
+            raise
+        return
+    elif task.get("task_type", "") == "graph_community":
+        start_ts = timer()
+        try:
+            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+            WithCommunity(
+                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+                progress_callback
+            )
+            progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
+        except TaskCanceledException:
+            raise
+        except Exception as e:
+            error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
+            progress_callback(-1, msg=error_message)
+            logging.exception(error_message)
+            raise
+        return
     else:
         # Standard chunking methods
         start_ts = timer()
```
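Taken together, `do_handle_task` now dispatches on four task types before falling back to standard chunking, and each graph stage (`graphrag`, `graph_resolution`, `graph_community`) returns early instead of falling through to the shared indexing code. A condensed, hypothetical sketch of that control flow (real branches bind an LLM and report progress):

```python
def dispatch(task, handlers):
    task_type = task.get("task_type", "")
    if task_type in ("raptor", "graphrag", "graph_resolution", "graph_community"):
        # In the real code the three graph stages `return` after running,
        # so they never re-chunk or re-embed the document.
        return handlers[task_type](task)
    return handlers["chunk"](task)  # default: standard chunking

handlers = {k: (lambda t, k=k: f"ran {k}") for k in
            ("raptor", "graphrag", "graph_resolution", "graph_community", "chunk")}
print(dispatch({"task_type": "graph_community"}, handlers))  # ran graph_community
print(dispatch({}, handlers))                                # ran chunk
```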
```diff
@@ -507,8 +577,6 @@ def do_handle_task(task):
         logging.info(progress_message)
         progress_callback(msg=progress_message)
 
-    # logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
-    init_kb(task, vector_size)
     chunk_count = len(set([chunk["id"] for chunk in chunks]))
     start_ts = timer()
     doc_store_result = ""
```