Refa: replace trio with asyncio (#11831)

### What problem does this PR solve?

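Change:
Replace trio with asyncio: nurseries become `asyncio.create_task` + `asyncio.gather` (cancelling the remaining tasks on failure), `trio.Semaphore` becomes `asyncio.Semaphore`, `trio.current_time()` becomes `asyncio.get_running_loop().time()`, and `trio.to_thread.run_sync` becomes `asyncio.to_thread`.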

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2025-12-09 19:23:14 +08:00
committed by GitHub
parent ca2d6f3301
commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import os
import re
@ -21,7 +22,6 @@ from copy import deepcopy
from typing import Callable
import networkx as nx
import trio
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
@ -109,14 +109,14 @@ class Extractor:
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
self.callback = callback
start_ts = trio.current_time()
start_ts = asyncio.get_running_loop().time()
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)
limiter = asyncio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
nonlocal error_count
@ -137,9 +137,19 @@ class Extractor:
if error_count > max_errors:
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
async with trio.open_nursery() as nursery:
for i, ck in enumerate(chunks):
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
tasks = [
asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
for i, ck in enumerate(chunks)
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in worker: {str(e)}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if error_count > 0:
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@ -166,7 +176,7 @@ class Extractor:
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
sum_token_count += token_count
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
start_ts = now
@ -176,14 +186,23 @@ class Extractor:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
tasks = [
asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
for en_nm, ents in maybe_nodes.items()
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error merging nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@ -194,14 +213,26 @@ class Extractor:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
tasks = []
for (src, tgt), rels in maybe_edges.items():
tasks.append(
asyncio.create_task(
self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
)
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error during relationships merging: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@ -309,5 +340,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary