mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add OCR's muti-gpus and parallel processing support (#5972)
### What problem does this PR solve? Add OCR's muti-gpus and parallel processing support ### Type of change - [x] New Feature (non-breaking change which adds functionality) @yuzhichang I've tried to resolve the comments in #5697. OCR jobs can now be done on both CPU and GPU. ( By the way, I've encountered a “Generate embedding error” issue #5954 that might be due to my outdated GPUs? idk. ) Please review it and give me suggestions. GPU:   CPU: 
This commit is contained in:
@ -20,6 +20,7 @@ import random
|
||||
from timeit import default_timer as timer
|
||||
import sys
|
||||
import threading
|
||||
import trio
|
||||
|
||||
import xgboost as xgb
|
||||
from io import BytesIO
|
||||
@ -41,7 +42,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
class RAGFlowPdfParser:
|
||||
def __init__(self):
|
||||
def __init__(self, parallel_devices: int | None = None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
@ -53,7 +54,13 @@ class RAGFlowPdfParser:
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.ocr = OCR()
|
||||
|
||||
self.ocr = OCR(parallel_devices = parallel_devices)
|
||||
self.parallel_devices = parallel_devices
|
||||
self.parallel_limiter = None
|
||||
if parallel_devices is not None and parallel_devices > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
|
||||
|
||||
if hasattr(self, "model_speciess"):
|
||||
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
||||
else:
|
||||
@ -63,7 +70,7 @@ class RAGFlowPdfParser:
|
||||
self.updown_cnt_mdl = xgb.Booster()
|
||||
if not settings.LIGHTEN:
|
||||
try:
|
||||
import torch
|
||||
import torch.cuda
|
||||
if torch.cuda.is_available():
|
||||
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
||||
except Exception:
|
||||
@ -283,9 +290,9 @@ class RAGFlowPdfParser:
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
def __ocr(self, pagenum, img, chars, ZM=3):
|
||||
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
|
||||
start = timer()
|
||||
bxs = self.ocr.detect(np.array(img))
|
||||
bxs = self.ocr.detect(np.array(img), device_id)
|
||||
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||
|
||||
start = timer()
|
||||
@ -330,7 +337,7 @@ class RAGFlowPdfParser:
|
||||
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
||||
boxes_to_reg.append(b)
|
||||
del b["txt"]
|
||||
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
|
||||
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg], device_id)
|
||||
for i in range(len(boxes_to_reg)):
|
||||
boxes_to_reg[i]["text"] = texts[i]
|
||||
del boxes_to_reg[i]["box_image"]
|
||||
@ -1022,28 +1029,54 @@ class RAGFlowPdfParser:
|
||||
else:
|
||||
self.is_english = False
|
||||
|
||||
start = timer()
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(
|
||||
np.median(sorted([c["height"] for c in chars])) if chars else 0
|
||||
)
|
||||
self.mean_width.append(
|
||||
np.median(sorted([c["width"] for c in chars])) if chars else 8
|
||||
)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
async def __img_ocr(i, id, img, chars, limiter):
|
||||
j = 0
|
||||
while j + 1 < len(chars):
|
||||
if chars[j]["text"] and chars[j + 1]["text"] \
|
||||
and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
|
||||
and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
|
||||
chars[j]["width"]) / 2:
|
||||
chars[j]["width"]) / 2:
|
||||
chars[j]["text"] += " "
|
||||
j += 1
|
||||
|
||||
self.__ocr(i + 1, img, chars, zoomin)
|
||||
if limiter:
|
||||
async with limiter:
|
||||
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
|
||||
else:
|
||||
self.__ocr(i + 1, img, chars, zoomin, id)
|
||||
|
||||
if callback and i % 6 == 5:
|
||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||
|
||||
async def __img_ocr_launcher():
|
||||
def __ocr_preprocess():
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(
|
||||
np.median(sorted([c["height"] for c in chars])) if chars else 0
|
||||
)
|
||||
self.mean_width.append(
|
||||
np.median(sorted([c["width"] for c in chars])) if chars else 8
|
||||
)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
return chars
|
||||
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
|
||||
self.parallel_limiter[i % self.parallel_devices])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
await __img_ocr(i, 0, img, chars, None)
|
||||
|
||||
start = timer()
|
||||
|
||||
trio.run(__img_ocr_launcher)
|
||||
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
if not self.is_english and not any(
|
||||
|
||||
Reference in New Issue
Block a user