mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-03 19:15:30 +08:00
Refa: PARALLEL_DEVICES is a static parameter. (#6168)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
@ -22,6 +22,7 @@ import os
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.settings import PARALLEL_DEVICES
|
||||
from .operators import * # noqa: F403
|
||||
from . import operators
|
||||
import math
|
||||
@ -509,7 +510,7 @@ class TextDetector:
|
||||
|
||||
|
||||
class OCR:
|
||||
def __init__(self, model_dir=None, parallel_devices: int | None = None):
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
@ -528,10 +529,10 @@ class OCR:
|
||||
"rag/res/deepdoc")
|
||||
|
||||
# Append muti-gpus task to the list
|
||||
if parallel_devices is not None and parallel_devices > 0:
|
||||
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(parallel_devices):
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
@ -543,11 +544,11 @@ class OCR:
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
|
||||
if parallel_devices is not None:
|
||||
assert parallel_devices > 0 , "Number of devices must be >= 1"
|
||||
if PARALLEL_DEVICES is not None:
|
||||
assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1"
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(parallel_devices):
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
|
||||
@ -34,15 +34,15 @@ import trio
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
|
||||
|
||||
|
||||
def main(args):
|
||||
import torch.cuda
|
||||
|
||||
cuda_devices = torch.cuda.device_count()
|
||||
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||
ocr = OCR(parallel_devices = cuda_devices)
|
||||
ocr = OCR()
|
||||
images, outputs = init_in_out(args)
|
||||
|
||||
|
||||
def __ocr(i, id, img):
|
||||
print("Task {} start".format(i))
|
||||
bxs = ocr(np.array(img), id)
|
||||
|
||||
Reference in New Issue
Block a user