mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Move 'get_project_base_directory' to common directory (#10940)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -42,7 +42,8 @@ from api.utils.api_utils import (
|
|||||||
server_error_response,
|
server_error_response,
|
||||||
validate_request,
|
validate_request,
|
||||||
)
|
)
|
||||||
from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
|
from api.utils.file_utils import filename_type, thumbnail
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
|
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
|
||||||
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMSer
|
|||||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.common.base64 import encode_to_base64
|
from api.common.base64 import encode_to_base64
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from api import settings
|
|||||||
from api.apps import app, smtp_mail_server
|
from api.apps import app, smtp_mail_server
|
||||||
from api.db.runtime_config import RuntimeConfig
|
from api.db.runtime_config import RuntimeConfig
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api import utils
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
from api.db.db_models import init_database_tables as init_web_db
|
from api.db.db_models import init_database_tables as init_web_db
|
||||||
from api.db.init_data import init_web_data
|
from api.db.init_data import init_web_data
|
||||||
@ -88,7 +88,7 @@ if __name__ == '__main__':
|
|||||||
f'RAGFlow version: {get_ragflow_version()}'
|
f'RAGFlow version: {get_ragflow_version()}'
|
||||||
)
|
)
|
||||||
logging.info(
|
logging.info(
|
||||||
f'project base: {utils.file_utils.get_project_base_directory()}'
|
f'project base: {get_project_base_directory()}'
|
||||||
)
|
)
|
||||||
show_configs()
|
show_configs()
|
||||||
settings.init_settings()
|
settings.init_settings()
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import rag.utils.infinity_conn
|
|||||||
import rag.utils.opensearch_conn
|
import rag.utils.opensearch_conn
|
||||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||||
from api.utils.configs import decrypt_database_config, get_base_config
|
from api.utils.configs import decrypt_database_config, get_base_config
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
|
|
||||||
LLM = None
|
LLM = None
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import pickle
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
from api.utils import file_utils
|
from api.utils import file_utils
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from api.utils.common import bytes_to_string, string_to_bytes
|
from api.utils.common import bytes_to_string, string_to_bytes
|
||||||
from api.constants import SERVICE_CONF
|
from api.constants import SERVICE_CONF
|
||||||
@ -30,7 +31,7 @@ from api.constants import SERVICE_CONF
|
|||||||
|
|
||||||
def conf_realpath(conf_name):
|
def conf_realpath(conf_name):
|
||||||
conf_path = f"conf/{conf_name}"
|
conf_path = f"conf/{conf_name}"
|
||||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
return os.path.join(get_project_base_directory(), conf_path)
|
||||||
|
|
||||||
|
|
||||||
def read_config(conf_name=SERVICE_CONF):
|
def read_config(conf_name=SERVICE_CONF):
|
||||||
@ -129,8 +130,7 @@ def decrypt_database_config(
|
|||||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||||
conf_path = conf_realpath(conf_name=conf_name)
|
conf_path = conf_realpath(conf_name=conf_name)
|
||||||
if not os.path.isabs(conf_path):
|
if not os.path.isabs(conf_path):
|
||||||
conf_path = os.path.join(
|
conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||||
file_utils.get_project_base_directory(), conf_path)
|
|
||||||
|
|
||||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||||
|
|||||||
@ -19,14 +19,14 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from Cryptodome.PublicKey import RSA
|
from Cryptodome.PublicKey import RSA
|
||||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
from api.utils import file_utils
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
def crypt(line):
|
def crypt(line):
|
||||||
"""
|
"""
|
||||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||||
"""
|
"""
|
||||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
file_path = os.path.join(get_project_base_directory(), "conf", "public.pem")
|
||||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||||
@ -35,7 +35,7 @@ def crypt(line):
|
|||||||
|
|
||||||
|
|
||||||
def decrypt(line):
|
def decrypt(line):
|
||||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
file_path = os.path.join(get_project_base_directory(), "conf", "private.pem")
|
||||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||||
@ -50,7 +50,7 @@ def decrypt2(crypt_text):
|
|||||||
hex_fixed = '00' + decode_data.hex()
|
hex_fixed = '00' + decode_data.hex()
|
||||||
decode_data = b16decode(hex_fixed.upper())
|
decode_data = b16decode(hex_fixed.upper())
|
||||||
|
|
||||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
file_path = os.path.join(get_project_base_directory(), "conf", "private.pem")
|
||||||
pem = open(file_path).read()
|
pem = open(file_path).read()
|
||||||
rsa_key = RSA.importKey(pem, "Welcome")
|
rsa_key = RSA.importKey(pem, "Welcome")
|
||||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from ruamel.yaml import YAML
|
|||||||
# Local imports
|
# Local imports
|
||||||
from api.constants import IMG_BASE64_PREFIX
|
from api.constants import IMG_BASE64_PREFIX
|
||||||
from api.db import FileType
|
from api.db import FileType
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||||
|
|
||||||
@ -51,21 +52,6 @@ if LOCK_KEY_pdfplumber not in sys.modules:
|
|||||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_project_base_directory(*args):
|
|
||||||
global PROJECT_BASE
|
|
||||||
if PROJECT_BASE is None:
|
|
||||||
PROJECT_BASE = os.path.abspath(
|
|
||||||
os.path.join(
|
|
||||||
os.path.dirname(os.path.realpath(__file__)),
|
|
||||||
os.pardir,
|
|
||||||
os.pardir,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if args:
|
|
||||||
return os.path.join(PROJECT_BASE, *args)
|
|
||||||
return PROJECT_BASE
|
|
||||||
|
|
||||||
@cached(cache=LRUCache(maxsize=10))
|
@cached(cache=LRUCache(maxsize=10))
|
||||||
def load_json_conf(conf_path):
|
def load_json_conf(conf_path):
|
||||||
if os.path.isabs(conf_path):
|
if os.path.isabs(conf_path):
|
||||||
|
|||||||
@ -17,19 +17,10 @@ import os
|
|||||||
import os.path
|
import os.path
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
initialized_root_logger = False
|
initialized_root_logger = False
|
||||||
|
|
||||||
def get_project_base_directory():
|
|
||||||
PROJECT_BASE = os.path.abspath(
|
|
||||||
os.path.join(
|
|
||||||
os.path.dirname(os.path.realpath(__file__)),
|
|
||||||
os.pardir,
|
|
||||||
os.pardir,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return PROJECT_BASE
|
|
||||||
|
|
||||||
def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
|
def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
|
||||||
global initialized_root_logger
|
global initialized_root_logger
|
||||||
if initialized_root_logger:
|
if initialized_root_logger:
|
||||||
|
|||||||
33
common/file_utils.py
Normal file
33
common/file_utils.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Cached project root. Seeded from the environment (RAG_PROJECT_BASE takes
# precedence over RAG_DEPLOY_BASE); when neither yields a usable value it is
# filled in lazily by get_project_base_directory().
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")


def get_project_base_directory(*args):
    """Return the project base directory, optionally joined with sub-paths.

    The base directory is the module-level ``PROJECT_BASE`` value. If that
    value is unset or empty, it is computed once as the parent of the
    directory containing this file and stored back into ``PROJECT_BASE`` so
    later calls reuse it.

    Args:
        *args: Optional path components appended to the base directory with
            ``os.path.join``.

    Returns:
        The base directory when called without arguments, otherwise the base
        directory joined with ``args``.
    """
    global PROJECT_BASE
    # Treat an empty string like None: an environment variable that is set
    # but empty would otherwise make every returned path relative to the CWD.
    if not PROJECT_BASE:
        PROJECT_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
            )
        )

    if args:
        return os.path.join(PROJECT_BASE, *args)
    return PROJECT_BASE
|
||||||
@ -34,7 +34,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pypdf import PdfReader as pdf2_read
|
from pypdf import PdfReader as pdf2_read
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from api.utils.common import pip_install_torch
|
||||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from deepdoc.vision import Recognizer
|
from deepdoc.vision import Recognizer
|
||||||
from deepdoc.vision.operators import nms
|
from deepdoc.vision.operators import nms
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import os
|
|||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from api.utils.common import pip_install_torch
|
||||||
from rag.settings import PARALLEL_DEVICES
|
from rag.settings import PARALLEL_DEVICES
|
||||||
from .operators import * # noqa: F403
|
from .operators import * # noqa: F403
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import cv2
|
|||||||
from functools import cmp_to_key
|
from functools import cmp_to_key
|
||||||
|
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from .operators import * # noqa: F403
|
from .operators import * # noqa: F403
|
||||||
from .operators import preprocess
|
from .operators import preprocess
|
||||||
from . import operators
|
from . import operators
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from collections import Counter
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
|
|
||||||
from .recognizer import Recognizer
|
from .recognizer import Recognizer
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import sys
|
|||||||
from hanziconv import HanziConv
|
from hanziconv import HanziConv
|
||||||
from nltk import word_tokenize
|
from nltk import word_tokenize
|
||||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
class RagTokenizer:
|
class RagTokenizer:
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
from nltk.corpus import wordnet
|
from nltk.corpus import wordnet
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
class Dealer:
|
class Dealer:
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import re
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
class Dealer:
|
class Dealer:
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from api.utils.configs import get_base_config, decrypt_database_config
|
from api.utils.configs import get_base_config, decrypt_database_config
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import pip_install_torch
|
from api.utils.common import pip_install_torch
|
||||||
|
|
||||||
# Server
|
# Server
|
||||||
|
|||||||
@ -28,7 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|||||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from api.utils.base64_image import image2id
|
from api.utils.base64_image import image2id
|
||||||
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
from api.utils.log_utils import init_root_logger
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.configs import show_configs
|
from api.utils.configs import show_configs
|
||||||
from graphrag.general.index import run_graphrag_for_kb
|
from graphrag.general.index import run_graphrag_for_kb
|
||||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
import os
|
import os
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
tiktoken_cache_dir = get_project_base_directory()
|
tiktoken_cache_dir = get_project_base_directory()
|
||||||
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from elastic_transport import ConnectionTimeout
|
|||||||
from rag import settings
|
from rag import settings
|
||||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||||
from common.decorator import singleton
|
from common.decorator import singleton
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from api.utils.common import convert_bytes
|
from api.utils.common import convert_bytes
|
||||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||||
FusionExpr
|
FusionExpr
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from rag import settings
|
|||||||
from rag.settings import PAGERANK_FLD, TAG_FLD
|
from rag.settings import PAGERANK_FLD, TAG_FLD
|
||||||
from common.decorator import singleton
|
from common.decorator import singleton
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from rag.nlp import is_english
|
from rag.nlp import is_english
|
||||||
|
|
||||||
from rag.utils.doc_store_conn import (
|
from rag.utils.doc_store_conn import (
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from opensearchpy import ConnectionTimeout
|
|||||||
from rag import settings
|
from rag import settings
|
||||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||||
from common.decorator import singleton
|
from common.decorator import singleton
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||||
FusionExpr
|
FusionExpr
|
||||||
from rag.nlp import is_english, rag_tokenizer
|
from rag.nlp import is_english, rag_tokenizer
|
||||||
|
|||||||
123
test/unit_test/common/test_file_utils.py
Normal file
123
test/unit_test/common/test_file_utils.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from common import file_utils
|
||||||
|
from common.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetProjectBaseDirectory:
    """Test cases for get_project_base_directory function.

    Tests that need a specific value in the module-level
    ``file_utils.PROJECT_BASE`` cache set it through ``patch.object`` so the
    previous value is restored afterwards and no state leaks between tests.
    """

    def test_returns_project_base_when_no_args(self):
        """Test that function returns project base directory when no arguments provided"""
        result = get_project_base_directory()

        assert result is not None
        assert isinstance(result, str)
        assert os.path.isabs(result)  # Should return absolute path

    def test_returns_path_with_single_argument(self):
        """Test that function joins project base with single additional path component"""
        result = get_project_base_directory("subfolder")

        assert result is not None
        assert "subfolder" in result
        assert result.endswith("subfolder")

    def test_returns_path_with_multiple_arguments(self):
        """Test that function joins project base with multiple path components"""
        result = get_project_base_directory("folder1", "folder2", "file.txt")

        assert result is not None
        assert "folder1" in result
        assert "folder2" in result
        assert "file.txt" in result
        assert os.path.basename(result) == "file.txt"

    def test_uses_environment_variable_when_available(self):
        """Test that function uses the configured project base when one is set"""
        test_path = "/custom/project/path"

        # The module seeds PROJECT_BASE when it is loaded, so a configured
        # base is emulated by overriding the cached value; patch.object
        # restores the previous value when the block exits.
        with patch.object(file_utils, "PROJECT_BASE", test_path):
            result = get_project_base_directory()

        assert result == test_path

    def test_calculates_default_path_when_no_env_vars(self):
        """Test that function calculates default path when no environment variables are set"""
        # Clear all environment variables and reset the cached value to force
        # the default path to be computed.
        with patch.dict(os.environ, {}, clear=True), patch.object(file_utils, "PROJECT_BASE", None):
            result = get_project_base_directory()

        expected = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(file_utils.__file__)),
                os.pardir,
            )
        )

        # Should return a valid absolute path
        assert result is not None
        assert os.path.isabs(result)
        assert os.path.basename(result) != ""  # Should not be root directory
        assert result == expected

    def test_caches_project_base_value(self):
        """Test that PROJECT_BASE is cached after first calculation"""
        # Reset the cached value so the first call has to calculate it.
        with patch.object(file_utils, "PROJECT_BASE", None):
            # First call should calculate the value
            first_result = get_project_base_directory()

            # Store the current value
            cached_value = file_utils.PROJECT_BASE

            # Second call should use cached value
            second_result = get_project_base_directory()

            assert cached_value == first_result
            assert first_result == second_result
            assert file_utils.PROJECT_BASE == cached_value

    def test_path_components_joined_correctly(self):
        """Test that path components are properly joined with the base directory"""
        base_path = get_project_base_directory()
        expected_path = os.path.join(base_path, "data", "files", "document.txt")

        result = get_project_base_directory("data", "files", "document.txt")

        assert result == expected_path

    def test_handles_empty_string_arguments(self):
        """Test that function handles empty string arguments correctly"""
        result = get_project_base_directory("")

        # Should still return a valid path (base directory)
        assert result is not None
        assert os.path.isabs(result)
|
||||||
|
|
||||||
|
|
||||||
|
# Parameterized tests for different path combinations
|
||||||
|
@pytest.mark.parametrize("path_args,expected_suffix", [
    ((), ""),  # No additional arguments
    (("src",), "src"),
    (("data", "models"), os.path.join("data", "models")),
    (("config", "app", "settings.json"), os.path.join("config", "app", "settings.json")),
])
def test_various_path_combinations(path_args, expected_suffix):
    """Check the joined path for several argument tuples.

    With no extra components the result must equal the base directory;
    otherwise it must end with the joined components.
    """
    project_root = get_project_base_directory()
    joined = get_project_base_directory(*path_args)

    # Guard clause: an empty suffix means no components were passed.
    if not expected_suffix:
        assert joined == project_root
        return

    assert joined.endswith(expected_suffix)
|
||||||
Reference in New Issue
Block a user