mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-25 08:06:48 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@ -25,7 +26,6 @@ from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import trio
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
@ -681,18 +681,37 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
async def _is_strong_enough():
|
||||
nonlocal chat_model, embedding_model
|
||||
if embedding_model:
|
||||
with trio.fail_after(10):
|
||||
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if chat_model:
|
||||
with trio.fail_after(30):
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
|
||||
if res.find("**ERROR**") >= 0:
|
||||
res = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
chat_model.chat,
|
||||
"Nothing special.",
|
||||
[{"role": "user", "content": "Are you strong enough!?"}],
|
||||
{}
|
||||
),
|
||||
timeout=30
|
||||
)
|
||||
if "**ERROR**" in res:
|
||||
raise Exception(res)
|
||||
|
||||
# Pressure test for GraphRAG task
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(_is_strong_enough)
|
||||
tasks = [
|
||||
asyncio.create_task(_is_strong_enough())
|
||||
for _ in range(count)
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Pressure test failed: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
|
||||
def get_allowed_llm_factories() -> list:
|
||||
|
||||
Reference in New Issue
Block a user