mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Move some functions out of 'api/utils/common.py' (#10948)
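### What problem does this PR solve?

As the title says: `convert_bytes`, `once`, and `pip_install_torch` are moved out of `api/utils/common.py` into `common/misc_utils.py`, and the modules that use them now import from `common.misc_utils`.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>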
This commit is contained in:
@ -14,12 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import threading
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
def string_to_bytes(string):
    """Return ``string`` as bytes.

    A ``bytes`` object is handed back unchanged; anything else is encoded
    with UTF-8 through its ``encode`` method.
    """
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
|
||||||
@ -28,70 +22,3 @@ def string_to_bytes(string):
|
|||||||
def bytes_to_string(byte):
    """Decode ``byte`` as UTF-8 and return the resulting text."""
    text = byte.decode(encoding="utf-8")
    return text
|
||||||
|
|
||||||
|
|
||||||
def convert_bytes(size_in_bytes: int) -> str:
    """
    Format a size in bytes as a short human-readable string.

    The value is scaled by powers of 1024 through the units B, KB, MB, GB, TB
    and PB (values beyond the PB range stay in PB). Plain byte counts are
    printed without decimals; scaled values are printed with roughly three
    significant digits: two decimals below 10, one decimal below 100 and no
    decimals from 100 upwards.

    Args:
        size_in_bytes: The size to format.

    Returns:
        The formatted size followed by a space and the unit, e.g. "1.00 KB".
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB", "PB")
    last = len(units) - 1
    size = float(size_in_bytes)
    i = 0

    while size >= 1024 and i < last:
        size /= 1024
        i += 1

    # A value such as 1023.9996 KB would otherwise be printed as "1024 KB";
    # move it to the next unit so it reads "1.00 MB" instead.
    if i < last and round(size) >= 1024:
        size /= 1024
        i += 1

    if i == 0:
        return f"{size:.0f} {units[i]}"

    # Choose the precision from the *rounded* text, so that e.g. 9.999 is
    # shown as "10.0" (one decimal) rather than "10.00".
    for decimals, limit in ((2, 10), (1, 100)):
        text = f"{size:.{decimals}f}"
        if float(text) < limit:
            return f"{text} {units[i]}"
    return f"{size:.0f} {units[i]}"
|
|
||||||
|
|
||||||
|
|
||||||
def once(func):
    """
    A thread-safe decorator that ensures the decorated function runs at most once,
    caching and returning its result for all subsequent calls.

    The first call runs ``func`` with the arguments it was given and stores the
    return value; every later call returns that stored value without running
    ``func`` again, whatever arguments are passed. A lock is held for the
    duration of each call, so concurrent callers wait until the first
    execution has finished.

    Note:
        The function is marked as executed *before* it runs. If the first call
        raises, the exception propagates to that caller and later calls return
        ``None`` instead of retrying.

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper function (carrying ``func``'s name and docstring)
                  that executes `func` on the first call and returns the
                  cached result thereafter.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    # Imported locally so the decorator does not depend on a module-level import.
    import functools

    executed = False
    result = None
    lock = threading.Lock()

    @functools.wraps(func)  # keep __name__/__doc__ of the decorated function
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                executed = True
                result = func(*args, **kwargs)
            return result

    return wrapper
|
|
||||||
|
|
||||||
@once
def pip_install_torch():
    """
    Install PyTorch with pip unless the ``DEVICE`` environment variable is "cpu".

    ``DEVICE`` defaults to "cpu", in which case nothing is installed. Otherwise
    pip is run through the current interpreter; ``subprocess.check_call``
    raises ``CalledProcessError`` if pip exits with a non-zero status.
    """
    if os.getenv("DEVICE", "cpu") == "cpu":
        return

    logging.info("Installing pytorch")
    requirements = ["torch>=2.5.0,<3.0.0"]
    command = [sys.executable, "-m", "pip", "install"] + requirements
    subprocess.check_call(command)
|
|
||||||
|
|||||||
@ -18,6 +18,11 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
import requests
|
import requests
|
||||||
|
import threading
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
def get_uuid():
    """Return a freshly generated version-1 UUID as a hexadecimal string."""
    generated = uuid.uuid1()
    return generated.hex
|
||||||
@ -34,3 +39,70 @@ def download_img(url):
|
|||||||
|
|
||||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
    """
    Map ``line`` to an integer in ``range(mod)`` using its SHA-1 digest.

    The text is UTF-8 encoded, hashed with SHA-1, and the digest, read as one
    big-endian unsigned integer, is reduced modulo ``mod``.
    """
    digest = hashlib.sha1(line.encode("utf-8")).digest()
    # Same number as int(hexdigest, 16): the digest as a big-endian integer.
    as_integer = int.from_bytes(digest, "big")
    return as_integer % mod
|
||||||
|
|
||||||
|
def convert_bytes(size_in_bytes: int) -> str:
    """
    Format a size in bytes as a short human-readable string.

    The value is scaled by powers of 1024 through the units B, KB, MB, GB, TB
    and PB (values beyond the PB range stay in PB). Plain byte counts are
    printed without decimals; scaled values are printed with roughly three
    significant digits: two decimals below 10, one decimal below 100 and no
    decimals from 100 upwards.

    Args:
        size_in_bytes: The size to format.

    Returns:
        The formatted size followed by a space and the unit, e.g. "1.00 KB".
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB", "PB")
    last = len(units) - 1
    size = float(size_in_bytes)
    i = 0

    while size >= 1024 and i < last:
        size /= 1024
        i += 1

    # A value such as 1023.9996 KB would otherwise be printed as "1024 KB";
    # move it to the next unit so it reads "1.00 MB" instead.
    if i < last and round(size) >= 1024:
        size /= 1024
        i += 1

    if i == 0:
        return f"{size:.0f} {units[i]}"

    # Choose the precision from the *rounded* text, so that e.g. 9.999 is
    # shown as "10.0" (one decimal) rather than "10.00".
    for decimals, limit in ((2, 10), (1, 100)):
        text = f"{size:.{decimals}f}"
        if float(text) < limit:
            return f"{text} {units[i]}"
    return f"{size:.0f} {units[i]}"
|
||||||
|
|
||||||
|
|
||||||
|
def once(func):
    """
    A thread-safe decorator that ensures the decorated function runs at most once,
    caching and returning its result for all subsequent calls.

    The first call runs ``func`` with the arguments it was given and stores the
    return value; every later call returns that stored value without running
    ``func`` again, whatever arguments are passed. A lock is held for the
    duration of each call, so concurrent callers wait until the first
    execution has finished.

    Note:
        The function is marked as executed *before* it runs. If the first call
        raises, the exception propagates to that caller and later calls return
        ``None`` instead of retrying.

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper function (carrying ``func``'s name and docstring)
                  that executes `func` on the first call and returns the
                  cached result thereafter.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    # Imported locally so the decorator does not depend on a module-level import.
    import functools

    executed = False
    result = None
    lock = threading.Lock()

    @functools.wraps(func)  # keep __name__/__doc__ of the decorated function
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                executed = True
                result = func(*args, **kwargs)
            return result

    return wrapper
|
||||||
|
|
||||||
|
@once
def pip_install_torch():
    """
    Install PyTorch with pip unless the ``DEVICE`` environment variable is "cpu".

    ``DEVICE`` defaults to "cpu", in which case nothing is installed. Otherwise
    pip is run through the current interpreter; ``subprocess.check_call``
    raises ``CalledProcessError`` if pip exits with a non-zero status.
    """
    if os.getenv("DEVICE", "cpu") == "cpu":
        return

    logging.info("Installing pytorch")
    requirements = ["torch>=2.5.0,<3.0.0"]
    command = [sys.executable, "-m", "pip", "install"] + requirements
    subprocess.check_call(command)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ from PIL import Image
|
|||||||
from pypdf import PdfReader as pdf2_read
|
from pypdf import PdfReader as pdf2_read
|
||||||
|
|
||||||
from common.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from common.misc_utils import pip_install_torch
|
||||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import os
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from common.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from common.misc_utils import pip_install_torch
|
||||||
from rag.settings import PARALLEL_DEVICES
|
from rag.settings import PARALLEL_DEVICES
|
||||||
from .operators import * # noqa: F403
|
from .operators import * # noqa: F403
|
||||||
from . import operators
|
from . import operators
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
from api.utils.configs import get_base_config, decrypt_database_config
|
from api.utils.configs import get_base_config, decrypt_database_config
|
||||||
from common.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from common.misc_utils import pip_install_torch
|
||||||
|
|
||||||
# Server
|
# Server
|
||||||
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from rag import settings
|
|||||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||||
from common.decorator import singleton
|
from common.decorator import singleton
|
||||||
from common.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import convert_bytes
|
from common.misc_utils import convert_bytes
|
||||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||||
FusionExpr
|
FusionExpr
|
||||||
from rag.nlp import is_english, rag_tokenizer
|
from rag.nlp import is_english, rag_tokenizer
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import uuid
|
import uuid
|
||||||
import hashlib
|
import hashlib
|
||||||
from common.misc_utils import get_uuid, download_img, hash_str2int
|
from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes
|
||||||
|
|
||||||
|
|
||||||
class TestGetUuid:
|
class TestGetUuid:
|
||||||
@ -270,3 +270,86 @@ class TestHashStr2Int:
|
|||||||
result = hash_str2int(test_str)
|
result = hash_str2int(test_str)
|
||||||
assert isinstance(result, int)
|
assert isinstance(result, int)
|
||||||
assert 0 <= result < 10 ** 8
|
assert 0 <= result < 10 ** 8
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertBytes:
    """Test suite for convert_bytes function"""

    def test_zero_bytes(self):
        """Test that 0 bytes returns '0 B'"""
        assert convert_bytes(0) == "0 B"

    def test_single_byte(self):
        """Test values in the plain byte range (no decimals)"""
        assert convert_bytes(1) == "1 B"
        assert convert_bytes(999) == "999 B"

    def test_kilobyte_range(self):
        """Test values in kilobyte range with different precisions"""
        # Exactly 1 KB
        assert convert_bytes(1024) == "1.00 KB"

        # Values that should show 1 decimal place (10-99.9 range)
        assert convert_bytes(15360) == "15.0 KB"  # 15 KB exactly
        assert convert_bytes(10752) == "10.5 KB"  # 10.5 KB

        # Values that should show 2 decimal places (1-9.99 range)
        assert convert_bytes(2048) == "2.00 KB"  # 2 KB exactly
        assert convert_bytes(3072) == "3.00 KB"  # 3 KB exactly
        assert convert_bytes(5120) == "5.00 KB"  # 5 KB exactly

    def test_megabyte_range(self):
        """Test values in megabyte range"""
        # Exactly 1 MB
        assert convert_bytes(1048576) == "1.00 MB"

        # Values with different precision requirements
        assert convert_bytes(15728640) == "15.0 MB"  # 15.0 MB
        assert convert_bytes(11010048) == "10.5 MB"  # 10.5 MB

    def test_gigabyte_range(self):
        """Test values in gigabyte range"""
        # Exactly 1 GB
        assert convert_bytes(1073741824) == "1.00 GB"

        # Below 10 units the value is shown with 2 decimal places
        assert convert_bytes(3221225472) == "3.00 GB"  # 3 GB exactly

    def test_terabyte_range(self):
        """Test values in terabyte range"""
        assert convert_bytes(1099511627776) == "1.00 TB"  # 1 TB

    def test_petabyte_range(self):
        """Test values in petabyte range"""
        assert convert_bytes(1125899906842624) == "1.00 PB"  # 1 PB

    def test_boundary_values(self):
        """Test values at unit boundaries"""
        # Just below 1 KB
        assert convert_bytes(1023) == "1023 B"

        # Just above 1 KB
        assert convert_bytes(1025) == "1.00 KB"

        # At 100 KB boundary (should switch to 0 decimal places)
        assert convert_bytes(102400) == "100 KB"
        assert convert_bytes(102300) == "99.9 KB"

    def test_precision_transitions(self):
        """Test the precision formatting transitions"""
        # Test transition from 2 decimal places to 1 decimal place
        assert convert_bytes(9216) == "9.00 KB"  # 9.00 KB (2 decimal places)
        assert convert_bytes(10240) == "10.0 KB"  # 10.0 KB (1 decimal place)

        # Test transition from 1 decimal place to 0 decimal places
        assert convert_bytes(102400) == "100 KB"  # 100 KB (0 decimal places)

    def test_large_values_no_overflow(self):
        """Test that very large values don't cause issues"""
        petabyte = 1125899906842624

        # Very large value that should use PB
        assert convert_bytes(10 * petabyte) == "10.0 PB"

        # Still within the PB range
        assert convert_bytes(100 * petabyte) == "100 PB"

        # Beyond the last unit the value must stay in PB instead of
        # indexing past the end of the unit list
        assert convert_bytes(2048 * petabyte) == "2048 PB"
|
||||||
|
|||||||
Reference in New Issue
Block a user