diff --git a/api/utils/common.py b/api/utils/common.py index 142d7cfe3..958cf20ff 100644 --- a/api/utils/common.py +++ b/api/utils/common.py @@ -14,12 +14,6 @@ # limitations under the License. # -import threading -import subprocess -import sys -import os -import logging - def string_to_bytes(string): return string if isinstance( string, bytes) else string.encode(encoding="utf-8") @@ -28,70 +22,3 @@ def string_to_bytes(string): def bytes_to_string(byte): return byte.decode(encoding="utf-8") - -def convert_bytes(size_in_bytes: int) -> str: - """ - Format size in bytes. - """ - if size_in_bytes == 0: - return "0 B" - - units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] - i = 0 - size = float(size_in_bytes) - - while size >= 1024 and i < len(units) - 1: - size /= 1024 - i += 1 - - if i == 0 or size >= 100: - return f"{size:.0f} {units[i]}" - elif size >= 10: - return f"{size:.1f} {units[i]}" - else: - return f"{size:.2f} {units[i]}" - - -def once(func): - """ - A thread-safe decorator that ensures the decorated function runs exactly once, - caching and returning its result for all subsequent calls. This prevents - race conditions in multi-threaded environments by using a lock to protect - the execution state. - - Args: - func (callable): The function to be executed only once. - - Returns: - callable: A wrapper function that executes `func` on the first call - and returns the cached result thereafter. 
import functools
import logging
import os
import subprocess
import sys
import threading


def convert_bytes(size_in_bytes: int) -> str:
    """
    Format a byte count as a short human-readable string.

    The value is scaled by powers of 1024 through the units B, KB, MB, GB,
    TB and PB; anything larger than PB stays expressed in PB. Plain byte
    counts are shown without decimals. Scaled values keep roughly three
    significant digits: two decimals below 10, one decimal below 100 and
    none from 100 upwards.

    Args:
        size_in_bytes: Size in bytes.

    Returns:
        The formatted size, e.g. ``"0 B"``, ``"1.00 KB"``, ``"10.5 MB"``
        or ``"100 GB"``.
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB", "PB")
    last = len(units) - 1

    size = float(size_in_bytes)
    idx = 0
    while size >= 1024 and idx < last:
        size /= 1024
        idx += 1

    if idx == 0:
        return f"{size:.0f} {units[idx]}"

    # Choose the precision from the value as it will be *displayed*, not the
    # raw value: 9.996 must print as "10.0" (not "10.00") and 99.95 as "100"
    # (not "100.0").
    if round(size, 2) < 10:
        digits = 2
    elif round(size, 1) < 100:
        digits = 1
    else:
        digits = 0
        # Rounding can also reach the next unit: 1023.999 KB would print as
        # "1024 KB", so promote it to "1.00 MB" while a larger unit exists.
        if idx < last and round(size) >= 1024:
            size /= 1024
            idx += 1
            digits = 2

    return f"{size:.{digits}f} {units[idx]}"


def once(func):
    """
    Decorator that runs ``func`` to completion a single time and caches the result.

    The first call that returns normally stores its return value; every later
    call returns that stored value without invoking ``func`` again, whatever
    arguments it is given. A lock serializes callers, so concurrent first
    calls cannot both execute ``func``.

    If ``func`` raises, the exception propagates to the caller and nothing is
    cached, so the next call tries again instead of silently returning
    ``None``.

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper that executes ``func`` until it first succeeds
        and returns the cached result thereafter.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    lock = threading.Lock()
    executed = False
    result = None

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                result = func(*args, **kwargs)
                # Flip the flag only after func returned, so a failure is
                # not recorded as a completed run.
                executed = True
            return result

    return wrapper


@once
def pip_install_torch():
    """
    Install PyTorch with pip unless the ``DEVICE`` environment variable is ``cpu``.

    ``DEVICE`` defaults to ``"cpu"`` when unset, in which case nothing is
    installed. For any other value, pip is run through the current
    interpreter (``sys.executable -m pip install``) with the requirement
    ``torch>=2.5.0,<3.0.0``. Because of ``@once`` the work is done at most
    one successful time per process.

    Raises:
        subprocess.CalledProcessError: If the pip command exits with a
            non-zero status.
    """
    device = os.getenv("DEVICE", "cpu")
    if device == "cpu":
        return
    logging.info("Installing pytorch")
    pkg_names = ["torch>=2.5.0,<3.0.0"]
    subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
class TestConvertBytes:
    """Test suite for convert_bytes.

    Expected strings follow the formatter's precision rule: plain bytes have
    no decimals; scaled values show two decimals below 10, one decimal below
    100 and none from 100 upwards.
    """

    def test_zero_bytes(self):
        """0 bytes is rendered as '0 B'."""
        assert convert_bytes(0) == "0 B"

    def test_single_byte(self):
        """Values below 1 KB are plain byte counts without decimals."""
        assert convert_bytes(1) == "1 B"
        assert convert_bytes(999) == "999 B"

    def test_kilobyte_range(self):
        """Kilobyte values use the precision that matches their magnitude."""
        # 1-9.99 range: two decimal places
        assert convert_bytes(1024) == "1.00 KB"
        assert convert_bytes(2048) == "2.00 KB"
        assert convert_bytes(3072) == "3.00 KB"
        assert convert_bytes(5120) == "5.00 KB"

        # 10-99.9 range: one decimal place
        assert convert_bytes(10752) == "10.5 KB"
        assert convert_bytes(15360) == "15.0 KB"

    def test_megabyte_range(self):
        """Megabyte values use the precision that matches their magnitude."""
        assert convert_bytes(1048576) == "1.00 MB"
        assert convert_bytes(11010048) == "10.5 MB"
        assert convert_bytes(15728640) == "15.0 MB"

    def test_gigabyte_range(self):
        """Gigabyte values use the precision that matches their magnitude."""
        # Below 10 GB: two decimal places
        assert convert_bytes(1073741824) == "1.00 GB"
        assert convert_bytes(3221225472) == "3.00 GB"

        # 100 GB and above: no decimal places
        assert convert_bytes(100 * 1024 ** 3) == "100 GB"

    def test_terabyte_range(self):
        """Exactly 1 TB."""
        assert convert_bytes(1099511627776) == "1.00 TB"

    def test_petabyte_range(self):
        """Exactly 1 PB."""
        assert convert_bytes(1125899906842624) == "1.00 PB"

    def test_boundary_values(self):
        """Values on either side of a unit or precision boundary."""
        # Just below 1 KB stays in bytes
        assert convert_bytes(1023) == "1023 B"

        # Just above 1 KB still rounds to 1.00 KB
        assert convert_bytes(1025) == "1.00 KB"

        # Either side of the 100 KB precision boundary
        assert convert_bytes(102300) == "99.9 KB"
        assert convert_bytes(102400) == "100 KB"

    def test_precision_transitions(self):
        """The number of decimals drops as the magnitude grows."""
        assert convert_bytes(9216) == "9.00 KB"    # two decimal places
        assert convert_bytes(10240) == "10.0 KB"   # one decimal place
        assert convert_bytes(102400) == "100 KB"   # no decimal places

    def test_large_values_no_overflow(self):
        """Values at and beyond the largest unit stay expressed in PB."""
        petabyte = 1125899906842624

        assert convert_bytes(10 * petabyte) == "10.0 PB"
        assert convert_bytes(100 * petabyte) == "100 PB"

        # There is no unit above PB, so the number simply keeps growing.
        assert convert_bytes(1024 * petabyte) == "1024 PB"