mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-24 07:26:47 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import binascii
|
||||
import logging
|
||||
import re
|
||||
@ -21,7 +22,6 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]]))
|
||||
return mind_map.output
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
@ -22,7 +23,6 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
import trio
|
||||
import xxhash
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@ -25,7 +26,6 @@ from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import trio
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
@ -681,18 +681,37 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
async def _is_strong_enough():
|
||||
nonlocal chat_model, embedding_model
|
||||
if embedding_model:
|
||||
with trio.fail_after(10):
|
||||
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if chat_model:
|
||||
with trio.fail_after(30):
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
|
||||
if res.find("**ERROR**") >= 0:
|
||||
res = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
chat_model.chat,
|
||||
"Nothing special.",
|
||||
[{"role": "user", "content": "Are you strong enough!?"}],
|
||||
{}
|
||||
),
|
||||
timeout=30
|
||||
)
|
||||
if "**ERROR**" in res:
|
||||
raise Exception(res)
|
||||
|
||||
# Pressure test for GraphRAG task
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(_is_strong_enough)
|
||||
tasks = [
|
||||
asyncio.create_task(_is_strong_enough())
|
||||
for _ in range(count)
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Pressure test failed: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
|
||||
def get_allowed_llm_factories() -> list:
|
||||
|
||||
Reference in New Issue
Block a user