mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Move some functions out of 'api/utils/common.py' (#10948)
### What problem does this PR solve? as title. ### Type of change - [x] Refactoring Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -14,12 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
def string_to_bytes(string):
    """Return ``string`` as bytes.

    A value that is already ``bytes`` is handed back untouched; anything
    else is encoded with its ``encode`` method using UTF-8.
    """
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
|
||||
@ -28,70 +22,3 @@ def string_to_bytes(string):
|
||||
def bytes_to_string(byte):
    """Decode ``byte`` as UTF-8 and return the resulting text."""
    decoded = byte.decode(encoding="utf-8")
    return decoded
|
||||
|
||||
|
||||
def convert_bytes(size_in_bytes: int) -> str:
    """
    Format a byte count as a short human-readable string.

    The value is scaled by powers of 1024 through B, KB, MB, GB, TB and PB
    (PB is the largest unit, so bigger values stay in PB). Plain bytes are
    shown as a whole number; scaled values are shown with roughly three
    significant digits: ``1.00 KB``, ``15.0 KB``, ``100 KB``.

    The number of decimals is chosen from the *rounded* value, so sizes just
    below a boundary never render as ``10.00 KB``, ``100.0 KB`` or
    ``1024 KB``; they become ``10.0 KB``, ``100 KB`` and ``1.00 MB``.
    Negative sizes are scaled by magnitude and keep their sign.

    Args:
        size_in_bytes: Number of bytes to format.

    Returns:
        The formatted size, e.g. ``"0 B"``, ``"999 B"``, ``"10.5 MB"``.
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB", "PB")
    last = len(units) - 1
    sign = "-" if size_in_bytes < 0 else ""
    size = abs(float(size_in_bytes))

    idx = 0
    while size >= 1024 and idx < last:
        size /= 1024
        idx += 1

    if idx == 0:
        return f"{sign}{size:.0f} {units[idx]}"

    # A value such as 1023.7 KB would print as "1024 KB"; move it up a unit
    # so it prints as "1.00 MB" instead (impossible for the last unit).
    if idx < last and float(f"{size:.0f}") >= 1024:
        size /= 1024
        idx += 1

    # Pick the precision from the text that will actually be displayed, so
    # rounding up across a bucket boundary falls through to fewer decimals.
    for limit, decimals in ((10, 2), (100, 1)):
        text = f"{size:.{decimals}f}"
        if float(text) < limit:
            return f"{sign}{text} {units[idx]}"
    return f"{sign}{size:.0f} {units[idx]}"
|
||||
|
||||
|
||||
def once(func):
    """
    A thread-safe decorator that runs the decorated function at most once,
    caching and returning its result for all subsequent calls. A lock guards
    the execution state so concurrent first calls cannot both run `func`.

    Notes:
        - Arguments passed on later calls are ignored; the result of the
          first call is always returned.
        - The "executed" flag is set before `func` is invoked, so if the
          first call raises, `func` is not retried and later calls return
          None.
        - The wrapper keeps the name and docstring of `func`
          (via functools.wraps).

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper function that executes `func` on the first call
                  and returns the cached result thereafter.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    # Imported locally so the decorator does not depend on the module's
    # import block.
    import functools

    executed = False
    result = None
    lock = threading.Lock()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                # Flip the flag first: a raising `func` must not be re-run.
                executed = True
                result = func(*args, **kwargs)
            return result

    return wrapper
|
||||
|
||||
@once
def pip_install_torch():
    """Install PyTorch with pip unless the DEVICE environment variable is "cpu".

    DEVICE defaults to "cpu" when unset, in which case nothing is installed.
    Otherwise pip is invoked through the current interpreter and
    subprocess.check_call raises CalledProcessError if it exits non-zero.
    """
    if os.getenv("DEVICE", "cpu") == "cpu":
        return
    logging.info("Installing pytorch")
    command = [sys.executable, "-m", "pip", "install", "torch>=2.5.0,<3.0.0"]
    subprocess.check_call(command)
|
||||
|
||||
@ -18,6 +18,11 @@ import base64
|
||||
import hashlib
|
||||
import uuid
|
||||
import requests
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
def get_uuid():
    """Return a freshly generated version-1 UUID as a 32-character hex string."""
    generated = uuid.uuid1()
    return generated.hex
|
||||
@ -34,3 +39,70 @@ def download_img(url):
|
||||
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
    """Map ``line`` to an integer in ``[0, mod)`` for positive ``mod``.

    The UTF-8 bytes of ``line`` are hashed with SHA-1, the digest is read as
    one big-endian unsigned integer, and that integer is reduced modulo
    ``mod``.
    """
    digest = hashlib.sha1(line.encode("utf-8")).digest()
    as_integer = int.from_bytes(digest, "big")
    return as_integer % mod
|
||||
|
||||
def convert_bytes(size_in_bytes: int) -> str:
    """
    Format a byte count as a short human-readable string.

    The value is scaled by powers of 1024 through B, KB, MB, GB, TB and PB
    (PB is the largest unit, so bigger values stay in PB). Plain bytes are
    shown as a whole number; scaled values are shown with roughly three
    significant digits: ``1.00 KB``, ``15.0 KB``, ``100 KB``.

    The number of decimals is chosen from the *rounded* value, so sizes just
    below a boundary never render as ``10.00 KB``, ``100.0 KB`` or
    ``1024 KB``; they become ``10.0 KB``, ``100 KB`` and ``1.00 MB``.
    Negative sizes are scaled by magnitude and keep their sign.

    Args:
        size_in_bytes: Number of bytes to format.

    Returns:
        The formatted size, e.g. ``"0 B"``, ``"999 B"``, ``"10.5 MB"``.
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB", "PB")
    last = len(units) - 1
    sign = "-" if size_in_bytes < 0 else ""
    size = abs(float(size_in_bytes))

    idx = 0
    while size >= 1024 and idx < last:
        size /= 1024
        idx += 1

    if idx == 0:
        return f"{sign}{size:.0f} {units[idx]}"

    # A value such as 1023.7 KB would print as "1024 KB"; move it up a unit
    # so it prints as "1.00 MB" instead (impossible for the last unit).
    if idx < last and float(f"{size:.0f}") >= 1024:
        size /= 1024
        idx += 1

    # Pick the precision from the text that will actually be displayed, so
    # rounding up across a bucket boundary falls through to fewer decimals.
    for limit, decimals in ((10, 2), (100, 1)):
        text = f"{size:.{decimals}f}"
        if float(text) < limit:
            return f"{sign}{text} {units[idx]}"
    return f"{sign}{size:.0f} {units[idx]}"
|
||||
|
||||
|
||||
def once(func):
    """
    A thread-safe decorator that runs the decorated function at most once,
    caching and returning its result for all subsequent calls. A lock guards
    the execution state so that, in multi-threaded environments, concurrent
    first calls cannot both run `func`.

    Notes:
        - Arguments passed on later calls are ignored; the result of the
          first call is always returned.
        - The "executed" flag is set before `func` is invoked, so if the
          first call raises, `func` is not retried and later calls return
          None.
        - The wrapper keeps the name and docstring of `func`
          (via functools.wraps).

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper function that executes `func` on the first call
                  and returns the cached result thereafter.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    # Imported locally so the decorator does not depend on the module's
    # import block.
    import functools

    executed = False
    result = None
    lock = threading.Lock()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                # Flip the flag first: a raising `func` must not be re-run.
                executed = True
                result = func(*args, **kwargs)
            return result

    return wrapper
|
||||
|
||||
@once
def pip_install_torch():
    """Install PyTorch with pip unless the DEVICE environment variable is "cpu".

    DEVICE defaults to "cpu" when unset, in which case nothing is installed.
    Otherwise pip is invoked through the current interpreter and
    subprocess.check_call raises CalledProcessError if it exits non-zero.
    """
    target_device = os.getenv("DEVICE", "cpu")
    if target_device != "cpu":
        logging.info("Installing pytorch")
        requirements = ["torch>=2.5.0,<3.0.0"]
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install"] + requirements
        )
|
||||
|
||||
@ -35,7 +35,7 @@ from PIL import Image
|
||||
from pypdf import PdfReader as pdf2_read
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
from api.utils.common import pip_install_torch
|
||||
from common.misc_utils import pip_install_torch
|
||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
@ -22,7 +22,7 @@ import os
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
from api.utils.common import pip_install_torch
|
||||
from common.misc_utils import pip_install_torch
|
||||
from rag.settings import PARALLEL_DEVICES
|
||||
from .operators import * # noqa: F403
|
||||
from . import operators
|
||||
|
||||
@ -17,7 +17,7 @@ import os
|
||||
import logging
|
||||
from api.utils.configs import get_base_config, decrypt_database_config
|
||||
from common.file_utils import get_project_base_directory
|
||||
from api.utils.common import pip_install_torch
|
||||
from common.misc_utils import pip_install_torch
|
||||
|
||||
# Server
|
||||
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
|
||||
@ -28,7 +28,7 @@ from rag import settings
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from api.utils.common import convert_bytes
|
||||
from common.misc_utils import convert_bytes
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
import uuid
|
||||
import hashlib
|
||||
from common.misc_utils import get_uuid, download_img, hash_str2int
|
||||
from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes
|
||||
|
||||
|
||||
class TestGetUuid:
|
||||
@ -270,3 +270,86 @@ class TestHashStr2Int:
|
||||
result = hash_str2int(test_str)
|
||||
assert isinstance(result, int)
|
||||
assert 0 <= result < 10 ** 8
|
||||
|
||||
|
||||
class TestConvertBytes:
    """Checks the human-readable size strings produced by convert_bytes."""

    def test_zero_bytes(self):
        """Zero is formatted as '0 B'."""
        assert convert_bytes(0) == "0 B"

    def test_single_byte(self):
        """Values below 1 KB are shown as whole bytes."""
        for raw, expected in ((1, "1 B"), (999, "999 B")):
            assert convert_bytes(raw) == expected

    def test_kilobyte_range(self):
        """Kilobyte values use two decimals below 10 and one decimal from 10 up."""
        expectations = (
            (1024, "1.00 KB"),    # exactly 1 KB
            (15360, "15.0 KB"),   # 15 KB -> one decimal
            (10752, "10.5 KB"),   # 10.5 KB -> one decimal
            (2048, "2.00 KB"),    # 2 KB -> two decimals
            (3072, "3.00 KB"),    # 3 KB -> two decimals
            (5120, "5.00 KB"),    # 5 KB -> two decimals
        )
        for raw, expected in expectations:
            assert convert_bytes(raw) == expected

    def test_megabyte_range(self):
        """Megabyte values follow the same precision rules."""
        expectations = (
            (1048576, "1.00 MB"),    # exactly 1 MB
            (15728640, "15.0 MB"),   # 15 MB
            (11010048, "10.5 MB"),   # 10.5 MB
        )
        for raw, expected in expectations:
            assert convert_bytes(raw) == expected

    def test_gigabyte_range(self):
        """Gigabyte values below 10 GB are shown with two decimals."""
        expectations = (
            (1073741824, "1.00 GB"),   # exactly 1 GB
            (3221225472, "3.00 GB"),   # exactly 3 GB, still two decimals
        )
        for raw, expected in expectations:
            assert convert_bytes(raw) == expected

    def test_terabyte_range(self):
        """Exactly 1 TB is shown as '1.00 TB'."""
        one_terabyte = 1099511627776
        assert convert_bytes(one_terabyte) == "1.00 TB"

    def test_petabyte_range(self):
        """Exactly 1 PB is shown as '1.00 PB'."""
        one_petabyte = 1125899906842624
        assert convert_bytes(one_petabyte) == "1.00 PB"

    def test_boundary_values(self):
        """Values right around the unit and precision boundaries."""
        expectations = (
            (1023, "1023 B"),      # just below 1 KB
            (1025, "1.00 KB"),     # just above 1 KB
            (102400, "100 KB"),    # 100 KB: no decimals
            (102300, "99.9 KB"),   # just under 100 KB: one decimal
        )
        for raw, expected in expectations:
            assert convert_bytes(raw) == expected

    def test_precision_transitions(self):
        """Precision drops from two decimals to one to none as the value grows."""
        expectations = (
            (9216, "9.00 KB"),     # below 10: two decimals
            (10240, "10.0 KB"),    # 10 and up: one decimal
            (102400, "100 KB"),    # 100 and up: no decimals
        )
        for raw, expected in expectations:
            assert convert_bytes(raw) == expected

    def test_large_values_no_overflow(self):
        """Values beyond 1 PB stay in PB, the largest available unit."""
        one_petabyte = 1125899906842624
        for multiplier in (10, 100):
            assert "PB" in convert_bytes(multiplier * one_petabyte)
|
||||
|
||||
Reference in New Issue
Block a user