Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)

### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2026-01-20 13:29:37 +08:00
committed by GitHub
parent 120648ac81
commit 927db0b373
30 changed files with 246 additions and 157 deletions

View File

@ -32,6 +32,7 @@ from graphrag.utils import (
set_embed_cache,
set_llm_cache,
)
from common.misc_utils import thread_pool_exec
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached
@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
return response
except Exception as exc:
last_exc = exc
@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(20)
async def _embedding_encode(self, txt):
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ")
embds = embds[0]
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):