From 312635cb13247a8ecc4c4e126081b72e5ab6bf8b Mon Sep 17 00:00:00 2001 From: Stephen Hu Date: Fri, 22 Aug 2025 10:58:02 +0800 Subject: [PATCH] Refactor: based on async await to handle Redis when raptor (#9576) ### What problem does this PR solve? based on async await to handle Redis when raptor ### Type of change - [x] Refactoring - [x] Performance Improvement --- rag/raptor.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rag/raptor.py b/rag/raptor.py index 9e0a8ad97..6ce776a68 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -44,7 +44,10 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(60*20) async def _chat(self, system, history, gen_conf): - response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) + response = await trio.to_thread.run_sync( + lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) + ) + if response: return response response = await trio.to_thread.run_sync( @@ -53,19 +56,23 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: response = re.sub(r"^.*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) - set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) + await trio.to_thread.run_sync( + lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) + ) return response @timeout(20) async def _embedding_encode(self, txt): - response = get_embed_cache(self._embd_model.llm_name, txt) + response = await trio.to_thread.run_sync( + lambda: get_embed_cache(self._embd_model.llm_name, txt) + ) if response is not None: return response embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) if len(embds) < 1 or len(embds[0]) < 1: raise Exception("Embedding error: ") embds = embds[0] - set_embed_cache(self._embd_model.llm_name, txt, embds) + await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds)) return embds def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):