Move some functions out of 'api/utils/common.py' (#10948)

### What problem does this PR solve?

as title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-03 12:34:47 +08:00
committed by GitHub
parent 4117f41758
commit 78631a3fd3
7 changed files with 161 additions and 79 deletions

View File

@ -14,12 +14,6 @@
# limitations under the License. # limitations under the License.
# #
import threading
import subprocess
import sys
import os
import logging
def string_to_bytes(string): def string_to_bytes(string):
return string if isinstance( return string if isinstance(
string, bytes) else string.encode(encoding="utf-8") string, bytes) else string.encode(encoding="utf-8")
@ -28,70 +22,3 @@ def string_to_bytes(string):
def bytes_to_string(byte): def bytes_to_string(byte):
return byte.decode(encoding="utf-8") return byte.decode(encoding="utf-8")
def convert_bytes(size_in_bytes: int) -> str:
"""
Format size in bytes.
"""
if size_in_bytes == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
i = 0
size = float(size_in_bytes)
while size >= 1024 and i < len(units) - 1:
size /= 1024
i += 1
if i == 0 or size >= 100:
return f"{size:.0f} {units[i]}"
elif size >= 10:
return f"{size:.1f} {units[i]}"
else:
return f"{size:.2f} {units[i]}"
def once(func):
"""
A thread-safe decorator that ensures the decorated function runs exactly once,
caching and returning its result for all subsequent calls. This prevents
race conditions in multi-threaded environments by using a lock to protect
the execution state.
Args:
func (callable): The function to be executed only once.
Returns:
callable: A wrapper function that executes `func` on the first call
and returns the cached result thereafter.
Example:
@once
def compute_expensive_value():
print("Computing...")
return 42
# First call: executes and prints
# Subsequent calls: return 42 without executing
"""
executed = False
result = None
lock = threading.Lock()
def wrapper(*args, **kwargs):
nonlocal executed, result
with lock:
if not executed:
executed = True
result = func(*args, **kwargs)
return result
return wrapper
@once
def pip_install_torch():
device = os.getenv("DEVICE", "cpu")
if device=="cpu":
return
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])

View File

@ -18,6 +18,11 @@ import base64
import hashlib import hashlib
import uuid import uuid
import requests import requests
import threading
import subprocess
import sys
import os
import logging
def get_uuid(): def get_uuid():
return uuid.uuid1().hex return uuid.uuid1().hex
@ -34,3 +39,70 @@ def download_img(url):
def hash_str2int(line: str, mod: int = 10 ** 8) -> int: def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
def convert_bytes(size_in_bytes: int) -> str:
"""
Format size in bytes.
"""
if size_in_bytes == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
i = 0
size = float(size_in_bytes)
while size >= 1024 and i < len(units) - 1:
size /= 1024
i += 1
if i == 0 or size >= 100:
return f"{size:.0f} {units[i]}"
elif size >= 10:
return f"{size:.1f} {units[i]}"
else:
return f"{size:.2f} {units[i]}"
def once(func):
"""
A thread-safe decorator that ensures the decorated function runs exactly once,
caching and returning its result for all subsequent calls. This prevents
race conditions in multi-thread environments by using a lock to protect
the execution state.
Args:
func (callable): The function to be executed only once.
Returns:
callable: A wrapper function that executes `func` on the first call
and returns the cached result thereafter.
Example:
@once
def compute_expensive_value():
print("Computing...")
return 42
# First call: executes and prints
# Subsequent calls: return 42 without executing
"""
executed = False
result = None
lock = threading.Lock()
def wrapper(*args, **kwargs):
nonlocal executed, result
with lock:
if not executed:
executed = True
result = func(*args, **kwargs)
return result
return wrapper
@once
def pip_install_torch():
device = os.getenv("DEVICE", "cpu")
if device=="cpu":
return
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])

View File

@ -35,7 +35,7 @@ from PIL import Image
from pypdf import PdfReader as pdf2_read from pypdf import PdfReader as pdf2_read
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from api.utils.common import pip_install_torch from common.misc_utils import pip_install_torch
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer

View File

@ -22,7 +22,7 @@ import os
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from api.utils.common import pip_install_torch from common.misc_utils import pip_install_torch
from rag.settings import PARALLEL_DEVICES from rag.settings import PARALLEL_DEVICES
from .operators import * # noqa: F403 from .operators import * # noqa: F403
from . import operators from . import operators

View File

@ -17,7 +17,7 @@ import os
import logging import logging
from api.utils.configs import get_base_config, decrypt_database_config from api.utils.configs import get_base_config, decrypt_database_config
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from api.utils.common import pip_install_torch from common.misc_utils import pip_install_torch
# Server # Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf") RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

View File

@ -28,7 +28,7 @@ from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton from common.decorator import singleton
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from api.utils.common import convert_bytes from common.misc_utils import convert_bytes
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr FusionExpr
from rag.nlp import is_english, rag_tokenizer from rag.nlp import is_english, rag_tokenizer

View File

@ -15,7 +15,7 @@
# #
import uuid import uuid
import hashlib import hashlib
from common.misc_utils import get_uuid, download_img, hash_str2int from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes
class TestGetUuid: class TestGetUuid:
@ -270,3 +270,86 @@ class TestHashStr2Int:
result = hash_str2int(test_str) result = hash_str2int(test_str)
assert isinstance(result, int) assert isinstance(result, int)
assert 0 <= result < 10 ** 8 assert 0 <= result < 10 ** 8
class TestConvertBytes:
"""Test suite for convert_bytes function"""
def test_zero_bytes(self):
"""Test that 0 bytes returns '0 B'"""
assert convert_bytes(0) == "0 B"
def test_single_byte(self):
"""Test single byte values"""
assert convert_bytes(1) == "1 B"
assert convert_bytes(999) == "999 B"
def test_kilobyte_range(self):
"""Test values in kilobyte range with different precisions"""
# Exactly 1 KB
assert convert_bytes(1024) == "1.00 KB"
# Values that should show 1 decimal place (10-99.9 range)
assert convert_bytes(15360) == "15.0 KB" # 15 KB exactly
assert convert_bytes(10752) == "10.5 KB" # 10.5 KB
# Values that should show 2 decimal places (1-9.99 range)
assert convert_bytes(2048) == "2.00 KB" # 2 KB exactly
assert convert_bytes(3072) == "3.00 KB" # 3 KB exactly
assert convert_bytes(5120) == "5.00 KB" # 5 KB exactly
def test_megabyte_range(self):
"""Test values in megabyte range"""
# Exactly 1 MB
assert convert_bytes(1048576) == "1.00 MB"
# Values with different precision requirements
assert convert_bytes(15728640) == "15.0 MB" # 15.0 MB
assert convert_bytes(11010048) == "10.5 MB" # 10.5 MB
def test_gigabyte_range(self):
"""Test values in gigabyte range"""
# Exactly 1 GB
assert convert_bytes(1073741824) == "1.00 GB"
# Large value that should show 0 decimal places
assert convert_bytes(3221225472) == "3.00 GB" # 3 GB exactly
def test_terabyte_range(self):
"""Test values in terabyte range"""
assert convert_bytes(1099511627776) == "1.00 TB" # 1 TB
def test_petabyte_range(self):
"""Test values in petabyte range"""
assert convert_bytes(1125899906842624) == "1.00 PB" # 1 PB
def test_boundary_values(self):
"""Test values at unit boundaries"""
# Just below 1 KB
assert convert_bytes(1023) == "1023 B"
# Just above 1 KB
assert convert_bytes(1025) == "1.00 KB"
# At 100 KB boundary (should switch to 0 decimal places)
assert convert_bytes(102400) == "100 KB"
assert convert_bytes(102300) == "99.9 KB"
def test_precision_transitions(self):
"""Test the precision formatting transitions"""
# Test transition from 2 decimal places to 1 decimal place
assert convert_bytes(9216) == "9.00 KB" # 9.00 KB (2 decimal places)
assert convert_bytes(10240) == "10.0 KB" # 10.0 KB (1 decimal place)
# Test transition from 1 decimal place to 0 decimal places
assert convert_bytes(102400) == "100 KB" # 100 KB (0 decimal places)
def test_large_values_no_overflow(self):
"""Test that very large values don't cause issues"""
# Very large value that should use PB
large_value = 10 * 1125899906842624 # 10 PB
assert "PB" in convert_bytes(large_value)
# Ensure we don't exceed available units
huge_value = 100 * 1125899906842624 # 100 PB (still within PB range)
assert "PB" in convert_bytes(huge_value)