Cache the result from llm for graphrag and raptor (#4051)

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
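The changes below lean on four helpers from `graphrag.utils` (`get_llm_cache`, `set_llm_cache`, `get_embed_cache`, `set_embed_cache`) whose bodies are not part of this diff. As a rough mental model only, assuming a key derived from every call argument and a plain in-process key-value store (the real helpers may use a shared backend such as Redis), they could look like this:

```python
import hashlib
import json

# Hypothetical stand-ins for graphrag.utils; the real implementation is not
# shown in this diff and may use a shared store instead of module-level dicts.
_LLM_CACHE = {}
_EMBED_CACHE = {}


def _key(*parts):
    # The key covers every argument, so a different model, prompt, history/tag,
    # or generation config never reuses a stale entry.
    return hashlib.sha1(json.dumps(parts, sort_keys=True, default=str).encode()).hexdigest()


def get_llm_cache(llm_name, system, history, gen_conf):
    return _LLM_CACHE.get(_key(llm_name, system, history, gen_conf))


def set_llm_cache(llm_name, system, response, history, gen_conf):
    _LLM_CACHE[_key(llm_name, system, history, gen_conf)] = response


def get_embed_cache(embd_name, txt):
    return _EMBED_CACHE.get(_key(embd_name, txt))


def set_embed_cache(embd_name, txt, embds):
    _EMBED_CACHE[_key(embd_name, txt)] = embds
```

Whatever the actual backend, the property the diff relies on is that identical requests map to the same key, so repeated summarization, keyword, and question calls can be answered without hitting the model again.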
```diff
@@ -21,6 +21,7 @@ import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 
+from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
 from rag.utils import truncate
```
```diff
@@ -33,6 +34,27 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token
 
+    def _chat(self, system, history, gen_conf):
+        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
+        if response:
+            return response
+        response = self._llm_model.chat(system, history, gen_conf)
+        if response.find("**ERROR**") >= 0:
+            raise Exception(response)
+        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
+        return response
+
+    def _embedding_encode(self, txt):
+        response = get_embed_cache(self._embd_model.llm_name, txt)
+        if response:
+            return response
+        embds, _ = self._embd_model.encode([txt])
+        if len(embds) < 1 or len(embds[0]) < 1:
+            raise Exception("Embedding error: ")
+        embds = embds[0]
+        set_embed_cache(self._embd_model.llm_name, txt, embds)
+        return embds
+
     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
         max_clusters = min(self._max_cluster, len(embeddings))
         n_clusters = np.arange(1, max_clusters)
```
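Both new methods follow a read-through (cache-aside) pattern: look up the cache first, call the model only on a miss, then store the result. A stand-alone sketch of that flow with a stub chat model and a plain dict (illustration only, not the repository's code):

```python
# Illustration only: a stub chat model plus the cache-aside flow that the new
# _chat method implements; a plain dict stands in for get_llm_cache/set_llm_cache.
class StubChatModel:
    llm_name = "stub-llm"

    def __init__(self):
        self.calls = 0

    def chat(self, system, history, gen_conf):
        self.calls += 1
        return "summary: " + history[0]["content"][:20]


_cache = {}


def cached_chat(model, system, history, gen_conf):
    key = (model.llm_name, system, str(history), str(gen_conf))
    if key in _cache:                # cache hit: the model is never called
        return _cache[key]
    response = model.chat(system, history, gen_conf)
    if "**ERROR**" in response:      # same error convention as the diff
        raise Exception(response)
    _cache[key] = response           # remember the answer for identical requests
    return response


model = StubChatModel()
args = ("You're a helpful assistant.",
        [{"role": "user", "content": "Summarize cluster 42 ..."}],
        {"temperature": 0.3, "max_tokens": 256})
assert cached_chat(model, *args) == cached_chat(model, *args)
assert model.calls == 1              # the second call was served from the cache
```

In the summarize path shown next, these wrappers replace the direct `chat` and `encode` calls.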
```diff
@@ -57,7 +79,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 texts = [chunks[i][0] for i in ck_idx]
                 len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                 cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                cnt = self._llm_model.chat("You're a helpful assistant.",
+                cnt = self._chat("You're a helpful assistant.",
                                  [{"role": "user",
                                    "content": self._prompt.format(cluster_content=cluster_content)}],
                                  {"temperature": 0.3, "max_tokens": self._max_token}
@@ -67,9 +89,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 logging.debug(f"SUM: {cnt}")
-                embds, _ = self._embd_model.encode([cnt])
                 with lock:
-                    if not len(embds[0]):
-                        return
-                    chunks.append((cnt, embds[0]))
+                    chunks.append((cnt, self._embedding_encode(cnt)))
             except Exception as e:
                 logging.exception("summarize got exception")
                 return e
```
```diff
@@ -19,6 +19,8 @@
 import sys
 from api.utils.log_utils import initRootLogger
+from graphrag.utils import get_llm_cache, set_llm_cache
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 initRootLogger(CONSUMER_NAME)
```
```diff
@@ -232,9 +234,6 @@ def build_chunks(task, progress_callback):
         if not d.get("image"):
             _ = d.pop("image", None)
             d["img_id"] = ""
-            d["page_num_int"] = []
-            d["position_int"] = []
-            d["top_int"] = []
             docs.append(d)
             continue
```
```diff
@@ -262,8 +261,16 @@ def build_chunks(task, progress_callback):
         progress_callback(msg="Start to generate keywords for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
-                                                    task["parser_config"]["auto_keywords"]).split(",")
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
+                                   {"topn": task["parser_config"]["auto_keywords"]})
+            if not cached:
+                cached = keyword_extraction(chat_mdl, d["content_with_weight"],
+                                            task["parser_config"]["auto_keywords"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
+                                  {"topn": task["parser_config"]["auto_keywords"]})
+
+            d["important_kwd"] = cached.split(",")
             d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
         progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
```
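The keyword loop above and the question loop below repeat the same check-compute-store sequence. Purely as an illustration (the helper name `cached_llm_call` and its signature are invented for this example, not part of the PR), the pattern could be written once:

```python
from graphrag.utils import get_llm_cache, set_llm_cache


# Hypothetical helper: the check-compute-store sequence used for both
# keywords and questions, factored into one place.
def cached_llm_call(chat_mdl, content, tag, topn, compute):
    cached = get_llm_cache(chat_mdl.llm_name, content, tag, {"topn": topn})
    if not cached:
        cached = compute()  # only reach the LLM on a cache miss
        if cached:
            set_llm_cache(chat_mdl.llm_name, content, cached, tag, {"topn": topn})
    return cached
```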
```diff
@@ -272,7 +279,15 @@ def build_chunks(task, progress_callback):
         progress_callback(msg="Start to generate questions for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
         for d in docs:
-            d["question_kwd"] = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]).split("\n")
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
+                                   {"topn": task["parser_config"]["auto_questions"]})
+            if not cached:
+                cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
+                if cached:
+                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
+                                  {"topn": task["parser_config"]["auto_questions"]})
+
+            d["question_kwd"] = cached.split("\n")
             d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
         progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
```
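Using that hypothetical helper, the question loop just shown would collapse to a single call per chunk. Again a sketch rather than the PR's code; `docs`, `chat_mdl`, `task`, `question_proposal`, and `rag_tokenizer` are the names from the surrounding diff:

```python
# Sketch: the question loop above, rewritten against the hypothetical helper.
for d in docs:
    cached = cached_llm_call(
        chat_mdl, d["content_with_weight"], "question",
        task["parser_config"]["auto_questions"],
        lambda: question_proposal(chat_mdl, d["content_with_weight"],
                                  task["parser_config"]["auto_questions"]),
    )
    d["question_kwd"] = cached.split("\n")
    d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
```

The keywords loop would reduce the same way, with the `"keywords"` tag and `keyword_extraction` as the compute callback.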