Refa: PARALLEL_DEVICES is a static parameter. (#6168)

### What problem does this PR solve?


### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2025-03-17 16:49:54 +08:00
committed by GitHub
parent 45fe02c8b3
commit 3a99c2b5f4
6 changed files with 29 additions and 28 deletions

View File

@ -37,13 +37,15 @@ from rag.nlp import rag_tokenizer
from copy import deepcopy
from huggingface_hub import snapshot_download
from rag.settings import PARALLEL_DEVICES
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
class RAGFlowPdfParser:
def __init__(self, parallel_devices: int | None = None):
def __init__(self):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -56,11 +58,10 @@ class RAGFlowPdfParser:
"""
self.ocr = OCR(parallel_devices = parallel_devices)
self.parallel_devices = parallel_devices
self.ocr = OCR()
self.parallel_limiter = None
if parallel_devices is not None and parallel_devices > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
if hasattr(self, "model_speciess"):
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
@ -1018,7 +1019,6 @@ class RAGFlowPdfParser:
self.pdf.close()
if not self.outlines:
logging.warning("Miss outlines")
logging.debug("Images converted.")
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
@ -1066,8 +1066,8 @@ class RAGFlowPdfParser:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
self.parallel_limiter[i % self.parallel_devices])
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
self.parallel_limiter[i % PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):

View File

@ -22,6 +22,7 @@ import os
from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from rag.settings import PARALLEL_DEVICES
from .operators import * # noqa: F403
from . import operators
import math
@ -509,7 +510,7 @@ class TextDetector:
class OCR:
def __init__(self, model_dir=None, parallel_devices: int | None = None):
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -528,10 +529,10 @@ class OCR:
"rag/res/deepdoc")
# Append muti-gpus task to the list
if parallel_devices is not None and parallel_devices > 0:
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(parallel_devices):
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
@ -543,11 +544,11 @@ class OCR:
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
if parallel_devices is not None:
assert parallel_devices > 0 , "Number of devices must be >= 1"
if PARALLEL_DEVICES is not None:
assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1"
self.text_detector = []
self.text_recognizer = []
for device_id in range(parallel_devices):
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:

View File

@ -34,15 +34,15 @@ import trio
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
def main(args):
import torch.cuda
cuda_devices = torch.cuda.device_count()
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
ocr = OCR(parallel_devices = cuda_devices)
ocr = OCR()
images, outputs = init_in_out(args)
def __ocr(i, id, img):
print("Task {} start".format(i))
bxs = ocr(np.array(img), id)