mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@ -28,7 +29,6 @@ from timeit import default_timer as timer
|
||||
|
||||
import numpy as np
|
||||
import pdfplumber
|
||||
import trio
|
||||
import xgboost as xgb
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
@ -65,7 +65,7 @@ class RAGFlowPdfParser:
|
||||
self.ocr = OCR()
|
||||
self.parallel_limiter = None
|
||||
if settings.PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
|
||||
self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)]
|
||||
|
||||
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
|
||||
if layout_recognizer_type not in ["onnx", "ascend"]:
|
||||
@ -382,7 +382,7 @@ class RAGFlowPdfParser:
|
||||
else:
|
||||
x0s.append([x])
|
||||
x0s = np.array(x0s, dtype=float)
|
||||
|
||||
|
||||
max_try = min(4, len(bxs))
|
||||
if max_try < 2:
|
||||
max_try = 1
|
||||
@ -416,7 +416,7 @@ class RAGFlowPdfParser:
|
||||
for pg, bxs in by_page.items():
|
||||
if not bxs:
|
||||
continue
|
||||
k = page_cols[pg]
|
||||
k = page_cols[pg]
|
||||
if len(bxs) < k:
|
||||
k = 1
|
||||
x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
|
||||
@ -430,7 +430,7 @@ class RAGFlowPdfParser:
|
||||
|
||||
for b, lb in zip(bxs, labels):
|
||||
b["col_id"] = remap[lb]
|
||||
|
||||
|
||||
grouped = defaultdict(list)
|
||||
for b in bxs:
|
||||
grouped[b["col_id"]].append(b)
|
||||
@ -1111,7 +1111,7 @@ class RAGFlowPdfParser:
|
||||
|
||||
if limiter:
|
||||
async with limiter:
|
||||
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
|
||||
await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
|
||||
else:
|
||||
self.__ocr(i + 1, img, chars, zoomin, id)
|
||||
|
||||
@ -1127,12 +1127,34 @@ class RAGFlowPdfParser:
|
||||
return chars
|
||||
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
tasks = []
|
||||
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES]
|
||||
|
||||
async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
|
||||
await __img_ocr(
|
||||
i,
|
||||
i % settings.PARALLEL_DEVICES,
|
||||
img,
|
||||
chars,
|
||||
semaphore,
|
||||
)
|
||||
|
||||
tasks.append(asyncio.create_task(wrapper()))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in OCR: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
@ -1140,7 +1162,7 @@ class RAGFlowPdfParser:
|
||||
|
||||
start = timer()
|
||||
|
||||
trio.run(__img_ocr_launcher)
|
||||
asyncio.run(__img_ocr_launcher())
|
||||
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user