diff --git a/rag/raptor.py b/rag/raptor.py
index 9e0a8ad97..6ce776a68 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -44,7 +44,10 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
     @timeout(60*20)
     async def _chat(self, system, history, gen_conf):
-        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
+        response = await trio.to_thread.run_sync(
+            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
+        )
+
         if response:
             return response
         response = await trio.to_thread.run_sync(
@@ -53,19 +56,23 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
-        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
+        await trio.to_thread.run_sync(
+            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
+        )
         return response
 
     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = get_embed_cache(self._embd_model.llm_name, txt)
+        response = await trio.to_thread.run_sync(
+            lambda: get_embed_cache(self._embd_model.llm_name, txt)
+        )
         if response is not None:
             return response
         embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
-        set_embed_cache(self._embd_model.llm_name, txt, embds)
+        await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
         return embds
 
     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
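
The diff applies one pattern throughout: the synchronous cache helpers (get_llm_cache, set_llm_cache, get_embed_cache, set_embed_cache) are handed to a worker thread with `trio.to_thread.run_sync` so they no longer block the event loop inside these async methods. Below is a minimal sketch of that pattern under stated assumptions, not RAGFlow code: `slow_cache_get` is a hypothetical stand-in for the cache helpers above.

```python
# Sketch only: a blocking cache lookup is pushed onto a worker thread via
# trio.to_thread.run_sync so other tasks keep running while it waits.
import time

import trio


def slow_cache_get(key: str):
    """Hypothetical stand-in for get_llm_cache / get_embed_cache.

    Simulates a blocking Redis/DB round trip with a sleep.
    """
    time.sleep(0.5)
    return None  # cache miss


async def cached_lookup(key: str):
    # A lambda binds the full call, mirroring the style used in the diff,
    # and run_sync awaits its result on a background thread.
    return await trio.to_thread.run_sync(lambda: slow_cache_get(key))


if __name__ == "__main__":
    print(trio.run(cached_lookup, "some-key"))
```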