mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -5,6 +5,7 @@ Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
import trio
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor):
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
async with chat_limiter:
|
||||
try:
|
||||
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before LLM call.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
timeout = 180 if enable_timeout_assertion else 1000000000
|
||||
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"extract_community_report._chat failed: {e}")
|
||||
return
|
||||
@ -141,17 +138,25 @@ class CommunityReportsExtractor(Extractor):
|
||||
if callback:
|
||||
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
|
||||
|
||||
st = trio.current_time()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for community in comm.items():
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
nursery.start_soon(extract_community_report, community)
|
||||
st = asyncio.get_running_loop().time()
|
||||
tasks = []
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for community in comm.items():
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
tasks.append(asyncio.create_task(extract_community_report(community)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in community processing: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
if callback:
|
||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
||||
callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
|
||||
Reference in New Issue
Block a user