Cache the result from llm for graphrag and raptor (#4051)

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2024-12-17 09:48:03 +08:00
committed by GitHub
parent 8ea631a2a0
commit cb6e9ce164
12 changed files with 161 additions and 38 deletions

View File

@ -21,6 +21,7 @@ import umap
import numpy as np
from sklearn.mixture import GaussianMixture
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
from rag.utils import truncate
@ -33,6 +34,27 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._prompt = prompt
self._max_token = max_token
def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
if response:
return response
response = self._llm_model.chat(system, history, gen_conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
return response
def _embedding_encode(self, txt):
response = get_embed_cache(self._embd_model.llm_name, txt)
if response:
return response
embds, _ = self._embd_model.encode([txt])
if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ")
embds = embds[0]
set_embed_cache(self._embd_model.llm_name, txt, embds)
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
max_clusters = min(self._max_cluster, len(embeddings))
n_clusters = np.arange(1, max_clusters)
@ -57,7 +79,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
cnt = self._llm_model.chat("You're a helpful assistant.",
cnt = self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
@ -67,9 +89,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
if not len(embds[0]):
return
chunks.append((cnt, embds[0]))
chunks.append((cnt, self._embedding_encode(cnt)))
except Exception as e:
logging.exception("summarize got exception")
return e