Refa: replace trio with asyncio (#11831)

### What problem does this PR solve?

Change: replace the `trio` concurrency library with the standard-library `asyncio` throughout the async code paths (a short API-mapping sketch follows the checklist below).

### Type of change
- [x] Refactoring
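
For reviewers less familiar with the two libraries, the migration is largely a mechanical API mapping. Below is a minimal sketch of the pattern this PR applies (the `run_all`/`work`/`items` names are illustrative, not taken from the diff):

```python
import asyncio

# trio -> asyncio correspondence used throughout this PR:
#   trio.to_thread.run_sync(lambda: f(a, b))  ->  asyncio.to_thread(f, a, b)
#   trio.sleep(n)                             ->  asyncio.sleep(n)
#   nursery.start_soon(coro_fn, x)            ->  asyncio.create_task(coro_fn(x))

async def run_all(items, work):
    # Spawn one task per item, as the clustering code below does per cluster.
    tasks = [asyncio.create_task(work(item)) for item in items]
    try:
        return await asyncio.gather(*tasks)
    except Exception:
        # Unlike a trio nursery, asyncio.gather does not cancel sibling
        # tasks when one fails, so they are cancelled by hand here, the
        # same way the diff does in the RAPTOR clustering loop.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
```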
Authored by buua436 on 2025-12-09 19:23:14 +08:00, committed by GitHub
parent ca2d6f3301 · commit 65a5a56d95
31 changed files with 821 additions and 429 deletions


@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import re
 import numpy as np
-import trio
 import umap
 from sklearn.mixture import GaussianMixture
@@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
+        cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
         if cached:
             return cached
         last_exc = None
         for attempt in range(3):
             try:
-                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
+                response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf)
                 response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
                 if response.find("**ERROR**") >= 0:
                     raise Exception(response)
-                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
+                await asyncio.to_thread(set_llm_cache, self._llm_model.llm_name, system, response, history, gen_conf)
                 return response
             except Exception as exc:
                 last_exc = exc
                 logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
                 if attempt < 2:
-                    await trio.sleep(1 + attempt)
+                    await asyncio.sleep(1 + attempt)
         raise last_exc if last_exc else Exception("LLM chat failed without exception")

     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
+        response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
+        embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
-        await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
+        await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
         return embds

     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
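
A note on the `to_thread` rewrites above: the old code wrapped each blocking call in a `lambda` for `trio.to_thread.run_sync`, while `asyncio.to_thread` forwards positional and keyword arguments to the callable itself, so the lambdas can be dropped. A minimal sketch, with `blocking_lookup` as a hypothetical stand-in for cache helpers like `get_llm_cache`:

```python
import asyncio

def blocking_lookup(model: str, key: str):
    # Hypothetical stand-in for a blocking helper such as get_llm_cache.
    return None

async def demo():
    # Old: await trio.to_thread.run_sync(lambda: blocking_lookup("m", "k"))
    # New: arguments are passed straight through; the call still runs in a
    # worker thread so the event loop stays responsive.
    return await asyncio.to_thread(blocking_lookup, "m", "k")
```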
@@ -198,16 +198,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         lbls = [np.where(prob > self._threshold)[0] for prob in probs]
         lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
-        async with trio.open_nursery() as nursery:
-            for c in range(n_clusters):
-                ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
-                assert len(ck_idx) > 0
-                if task_id and has_canceled(task_id):
-                    logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
-                    raise TaskCanceledException(f"Task {task_id} was cancelled")
-                nursery.start_soon(summarize, ck_idx)
+        tasks = []
+        for c in range(n_clusters):
+            ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
+            assert len(ck_idx) > 0
+            if task_id and has_canceled(task_id):
+                logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
+                raise TaskCanceledException(f"Task {task_id} was cancelled")
+            tasks.append(asyncio.create_task(summarize(ck_idx)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error in RAPTOR cluster processing: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
         assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
         labels.extend(lbls)
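
The `create_task`/`gather` block above recreates the nursery's cancel-siblings-on-failure behavior by hand. For comparison, on Python 3.11+ the same structure is available directly via `asyncio.TaskGroup`; a sketch under that version assumption, reusing the diff's per-cluster `summarize` coroutine:

```python
import asyncio

async def process_clusters(cluster_indexes, summarize):
    # TaskGroup (Python 3.11+) cancels the remaining tasks as soon as one
    # fails and re-raises the failures as an ExceptionGroup, matching
    # trio's nursery semantics without a manual cancel loop.
    async with asyncio.TaskGroup() as tg:
        for ck_idx in cluster_indexes:
            tg.create_task(summarize(ck_idx))
```

The manual version in the diff achieves the same effect without requiring Python 3.11.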