mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve?

Change: replace trio with asyncio.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import itertools
|
||||
import os
|
||||
@ -21,7 +22,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.nlp import is_english
|
||||
@ -101,35 +101,56 @@ class EntityResolution(Extractor):
|
||||
remain_candidates_to_resolve = num_candidates
|
||||
|
||||
resolution_result = set()
|
||||
resolution_result_lock = trio.Lock()
|
||||
resolution_result_lock = asyncio.Lock()
|
||||
resolution_batch_size = 100
|
||||
max_concurrent_tasks = 5
|
||||
semaphore = trio.Semaphore(max_concurrent_tasks)
|
||||
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||
|
||||
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
|
||||
nonlocal remain_candidates_to_resolve, callback
|
||||
async with semaphore:
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
|
||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
||||
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
|
||||
if cancel_scope.cancelled_caught:
|
||||
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
|
||||
timeout=timeout_sec
|
||||
)
|
||||
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||
callback(
|
||||
msg=f"Resolved {len(candidate_batch[1])} pairs, "
|
||||
f"{remain_candidates_to_resolve} remain."
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
|
||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
||||
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
|
||||
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||
callback(
|
||||
msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
|
||||
f"{remain_candidates_to_resolve} remain."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error resolving candidate batch: {e}")
|
||||
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if not candidate_resolution_i[1]:
|
||||
continue
|
||||
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
|
||||
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
|
||||
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
|
||||
tasks = []
|
||||
for key, lst in candidate_resolution.items():
|
||||
if not lst:
|
||||
continue
|
||||
for i in range(0, len(lst), resolution_batch_size):
|
||||
batch = (key, lst[i:i + resolution_batch_size])
|
||||
tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error resolving candidate pairs: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||
|
||||
@ -141,10 +162,19 @@ class EntityResolution(Extractor):
|
||||
async with semaphore:
|
||||
await self._merge_graph_nodes(graph, nodes, change, task_id)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
merging_nodes = list(sub_connect_graph)
|
||||
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
|
||||
tasks = []
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
merging_nodes = list(sub_connect_graph)
|
||||
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
|
||||
)
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error merging nodes: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
# Update pagerank
|
||||
pr = nx.pagerank(graph)
|
||||
@ -156,7 +186,7 @@ class EntityResolution(Extractor):
|
||||
change=change,
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
@ -178,13 +208,22 @@ class EntityResolution(Extractor):
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
|
||||
async with chat_limiter:
|
||||
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
|
||||
try:
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
if cancel_scope.cancelled_caught:
|
||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||
return
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
self._chat,
|
||||
text,
|
||||
[{"role": "user", "content": "Output:"}],
|
||||
{},
|
||||
task_id
|
||||
),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"_resolve_candidate._chat failed: {e}")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user