Refa:replace trio with asyncio (#11831)

### What problem does this PR solve?

change:
replace trio with asyncio

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2025-12-09 19:23:14 +08:00
committed by GitHub
parent ca2d6f3301
commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import itertools
import os
@ -21,7 +22,6 @@ from dataclasses import dataclass
from typing import Any, Callable
import networkx as nx
import trio
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
@ -101,35 +101,56 @@ class EntityResolution(Extractor):
remain_candidates_to_resolve = num_candidates
resolution_result = set()
resolution_result_lock = trio.Lock()
resolution_result_lock = asyncio.Lock()
resolution_batch_size = 100
max_concurrent_tasks = 5
semaphore = trio.Semaphore(max_concurrent_tasks)
semaphore = asyncio.Semaphore(max_concurrent_tasks)
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback
async with semaphore:
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
if cancel_scope.cancelled_caught:
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
try:
await asyncio.wait_for(
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
timeout=timeout_sec
)
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Resolved {len(candidate_batch[1])} pairs, "
f"{remain_candidates_to_resolve} remain."
)
except asyncio.TimeoutError:
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
f"{remain_candidates_to_resolve} remain."
)
except Exception as e:
logging.error(f"Error resolving candidate batch: {e}")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
tasks = []
for key, lst in candidate_resolution.items():
if not lst:
continue
for i in range(0, len(lst), resolution_batch_size):
batch = (key, lst[i:i + resolution_batch_size])
tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error resolving candidate pairs: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
@ -141,10 +162,19 @@ class EntityResolution(Extractor):
async with semaphore:
await self._merge_graph_nodes(graph, nodes, change, task_id)
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
tasks = []
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error merging nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
# Update pagerank
pr = nx.pagerank(graph)
@ -156,7 +186,7 @@ class EntityResolution(Extractor):
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
@ -178,13 +208,22 @@ class EntityResolution(Extractor):
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter:
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
response = await asyncio.wait_for(
asyncio.to_thread(
self._chat,
text,
[{"role": "user", "content": "Output:"}],
{},
task_id
),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
return