Refa:replace trio with asyncio (#11831)

### What problem does this PR solve?

change:
replace trio with asyncio

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2025-12-09 19:23:14 +08:00
committed by GitHub
parent ca2d6f3301
commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from common.connection_utils import timeout
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
for k, v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
await asyncio.wait_for(
self._invoke(**kwargs),
timeout=self._param.timeout
)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
from copy import deepcopy
from functools import partial
import trio
from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -178,9 +178,18 @@ class HierarchicalMerger(ProcessBase):
}
for c, img in zip(cks, images)
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in cks:
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in image2id: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@ -20,8 +20,8 @@ import random
import re
from functools import partial
from litellm import logging
import numpy as np
import trio
from PIL import Image
from api.db.services.file2document_service import File2DocumentService
@ -834,7 +834,7 @@ class Parser(ProcessBase):
for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], name, blob)
await asyncio.to_thread(function_map[p_type], name, blob)
done = True
break
@ -842,6 +842,15 @@ class Parser(ProcessBase):
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in outs.get("json", []):
tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error while parsing: %s" % e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import logging
import random
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
@ -152,8 +152,9 @@ class Pipeline(Graph):
#else:
# cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
tasks = []
tasks.append(asyncio.create_task(invoke()))
await asyncio.gather(*tasks)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()

View File

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
from copy import deepcopy
from functools import partial
import trio
from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -129,9 +130,17 @@ class Splitter(ProcessBase):
}
for c, img in zip(chunks, images) if c.strip()
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in cks:
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"error when splitting: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if custom_pattern:
docs = []

View File

@ -14,13 +14,12 @@
# limitations under the License.
#
import argparse
import asyncio
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from common import settings
from rag.flow.pipeline import Pipeline
@ -57,5 +56,5 @@ if __name__ == "__main__":
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
trio.run(pipeline.run)
asyncio.run(pipeline.run())
thr.result()

View File

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
import numpy as np
import trio
from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
cnts_ = np.array([])
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
if len(cnts_) == 0:
cnts_ = vts
else: