mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 08:35:08 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitations (#12716)
### Type of change

- [x] Refactoring
This commit is contained in:
@ -32,6 +32,7 @@ from graphrag.utils import (
|
||||
set_embed_cache,
|
||||
set_llm_cache,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
|
||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def _chat(self, system, history, gen_conf):
|
||||
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
|
||||
await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
|
||||
return response
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(20)
|
||||
async def _embedding_encode(self, txt):
|
||||
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
|
||||
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
|
||||
if response is not None:
|
||||
return response
|
||||
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
|
||||
embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
|
||||
if len(embds) < 1 or len(embds[0]) < 1:
|
||||
raise Exception("Embedding error: ")
|
||||
embds = embds[0]
|
||||
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
|
||||
await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
|
||||
return embds
|
||||
|
||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
||||
|
||||
Reference in New Issue
Block a user