mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 12:06:42 +08:00
Refa: replace trio with asyncio (#11831)
### What problem does this PR solve?

Change: replace trio with asyncio.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@ -21,7 +22,6 @@ from copy import deepcopy
|
||||
from typing import Callable
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.connection_utils import timeout
|
||||
@ -109,14 +109,14 @@ class Extractor:
|
||||
|
||||
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
|
||||
self.callback = callback
|
||||
start_ts = trio.current_time()
|
||||
start_ts = asyncio.get_running_loop().time()
|
||||
|
||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
||||
out_results = []
|
||||
error_count = 0
|
||||
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
|
||||
|
||||
limiter = trio.Semaphore(max_concurrency)
|
||||
limiter = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
|
||||
nonlocal error_count
|
||||
@ -137,9 +137,19 @@ class Extractor:
|
||||
if error_count > max_errors:
|
||||
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, ck in enumerate(chunks):
|
||||
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
|
||||
tasks = [
|
||||
asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
|
||||
for i, ck in enumerate(chunks)
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in worker: {str(e)}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if error_count > 0:
|
||||
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
|
||||
@ -166,7 +176,7 @@ class Extractor:
|
||||
for k, v in m_edges.items():
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
sum_token_count += token_count
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
if self.callback:
|
||||
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
|
||||
start_ts = now
|
||||
@ -176,14 +186,23 @@ class Extractor:
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for en_nm, ents in maybe_nodes.items():
|
||||
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
|
||||
tasks = [
|
||||
asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
|
||||
for en_nm, ents in maybe_nodes.items()
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error merging nodes: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
if self.callback:
|
||||
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
|
||||
|
||||
@ -194,14 +213,26 @@ class Extractor:
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
|
||||
tasks = []
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
|
||||
)
|
||||
)
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during relationships merging: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if task_id and has_canceled(task_id):
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
if self.callback:
|
||||
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
|
||||
|
||||
@ -309,5 +340,5 @@ class Extractor:
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
async with chat_limiter:
|
||||
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
return summary
|
||||
|
||||
Reference in New Issue
Block a user