mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 15:06:50 +08:00
Refa: replace trio with asyncio (#11831)
### What problem does this PR solve?

Change: replace trio with asyncio.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -13,12 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import trio
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from common.connection_utils import timeout
|
||||
|
||||
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
|
||||
for k, v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
await asyncio.wait_for(
|
||||
self._invoke(**kwargs),
|
||||
timeout=self._param.timeout
|
||||
)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
|
||||
@ -13,13 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
@ -178,9 +178,18 @@ class HierarchicalMerger(ProcessBase):
|
||||
}
|
||||
for c, img in zip(cks, images)
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in cks:
|
||||
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in image2id: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
self.set_output("chunks", cks)
|
||||
|
||||
self.callback(1, "Done.")
|
||||
|
||||
@ -20,8 +20,8 @@ import random
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from litellm import logging
|
||||
import numpy as np
|
||||
import trio
|
||||
from PIL import Image
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
@ -834,7 +834,7 @@ class Parser(ProcessBase):
|
||||
for p_type, conf in self._param.setups.items():
|
||||
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await trio.to_thread.run_sync(function_map[p_type], name, blob)
|
||||
await asyncio.to_thread(function_map[p_type], name, blob)
|
||||
done = True
|
||||
break
|
||||
|
||||
@ -842,6 +842,15 @@ class Parser(ProcessBase):
|
||||
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
|
||||
|
||||
outs = self.output()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in outs.get("json", []):
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in outs.get("json", []):
|
||||
tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid())))
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error("Error while parsing: %s" % e)
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
@ -13,12 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
@ -152,8 +152,9 @@ class Pipeline(Graph):
|
||||
#else:
|
||||
# cpn_obj.invoke(**last_cpn.output())
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
tasks = []
|
||||
tasks.append(asyncio.create_task(invoke()))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
|
||||
@ -12,11 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
import trio
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
@ -129,9 +130,17 @@ class Splitter(ProcessBase):
|
||||
}
|
||||
for c, img in zip(chunks, images) if c.strip()
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in cks:
|
||||
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"error when splitting: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if custom_pattern:
|
||||
docs = []
|
||||
|
||||
@ -14,13 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import trio
|
||||
|
||||
from common import settings
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
@ -57,5 +56,5 @@ if __name__ == "__main__":
|
||||
|
||||
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
|
||||
|
||||
trio.run(pipeline.run)
|
||||
asyncio.run(pipeline.run())
|
||||
thr.result()
|
||||
|
||||
@ -12,12 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
from common.constants import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
|
||||
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user