diff --git a/api/apps/document_app.py b/api/apps/document_app.py index fb6f9fdd4..934da374c 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -42,7 +42,8 @@ from api.utils.api_utils import ( server_error_response, validate_request, ) -from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail +from api.utils.file_utils import filename_type, thumbnail +from common.file_utils import get_project_base_directory from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url from deepdoc.parser.html_parser import RAGFlowHtmlParser from rag.nlp import search, rag_tokenizer diff --git a/api/db/init_data.py b/api/db/init_data.py index 39b87d06f..78ef6cfc9 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -30,7 +30,7 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMSer from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import TenantService, UserTenantService from api import settings -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from api.common.base64 import encode_to_base64 diff --git a/api/ragflow_server.py b/api/ragflow_server.py index fb49f3d8b..8ce391649 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -36,7 +36,7 @@ from api import settings from api.apps import app, smtp_mail_server from api.db.runtime_config import RuntimeConfig from api.db.services.document_service import DocumentService -from api import utils +from common.file_utils import get_project_base_directory from api.db.db_models import init_database_tables as init_web_db from api.db.init_data import init_web_data @@ -88,7 +88,7 @@ if __name__ == '__main__': f'RAGFlow version: {get_ragflow_version()}' ) logging.info( - f'project base: {utils.file_utils.get_project_base_directory()}' + f'project base: {get_project_base_directory()}' ) show_configs() settings.init_settings() diff --git a/api/settings.py b/api/settings.py index c8573bf29..592753eb2 100644 --- a/api/settings.py +++ b/api/settings.py @@ -25,7 +25,7 @@ import rag.utils.infinity_conn import rag.utils.opensearch_conn from api.constants import RAG_FLOW_SERVICE_NAME from api.utils.configs import decrypt_database_config, get_base_config -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from rag.nlp import search LLM = None diff --git a/api/utils/configs.py b/api/utils/configs.py index 48e492246..c586d9091 100644 --- a/api/utils/configs.py +++ b/api/utils/configs.py @@ -23,6 +23,7 @@ import pickle import importlib from api.utils import file_utils +from common.file_utils import get_project_base_directory from filelock import FileLock from api.utils.common import bytes_to_string, string_to_bytes from api.constants import SERVICE_CONF @@ -30,7 +31,7 @@ from api.constants import SERVICE_CONF def conf_realpath(conf_name): conf_path = f"conf/{conf_name}" - return os.path.join(file_utils.get_project_base_directory(), conf_path) + return os.path.join(get_project_base_directory(), conf_path) def read_config(conf_name=SERVICE_CONF): @@ -129,8 +130,7 @@ def decrypt_database_config( def update_config(key, value, conf_name=SERVICE_CONF): conf_path = conf_realpath(conf_name=conf_name) if not os.path.isabs(conf_path): - conf_path = os.path.join( - file_utils.get_project_base_directory(), conf_path) + conf_path = os.path.join(get_project_base_directory(), conf_path) with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): config = file_utils.load_yaml_conf(conf_path=conf_path) or {} diff --git a/api/utils/crypt.py b/api/utils/crypt.py index eb922a886..174ca3568 100644 --- a/api/utils/crypt.py +++ b/api/utils/crypt.py @@ -19,14 +19,14 @@ import os import sys from Cryptodome.PublicKey import RSA from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 -from api.utils import file_utils +from common.file_utils import get_project_base_directory def crypt(line): """ decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use. """ - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") + file_path = os.path.join(get_project_base_directory(), "conf", "public.pem") rsa_key = RSA.importKey(open(file_path).read(), "Welcome") cipher = Cipher_pkcs1_v1_5.new(rsa_key) password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8") @@ -35,7 +35,7 @@ def crypt(line): def decrypt(line): - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") + file_path = os.path.join(get_project_base_directory(), "conf", "private.pem") rsa_key = RSA.importKey(open(file_path).read(), "Welcome") cipher = Cipher_pkcs1_v1_5.new(rsa_key) return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') @@ -50,7 +50,7 @@ def decrypt2(crypt_text): hex_fixed = '00' + decode_data.hex() decode_data = b16decode(hex_fixed.upper()) - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") + file_path = os.path.join(get_project_base_directory(), "conf", "private.pem") pem = open(file_path).read() rsa_key = RSA.importKey(pem, "Welcome") cipher = Cipher_PKCS1_v1_5.new(rsa_key) diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 999a1147d..8eb84dcd2 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -43,6 +43,7 @@ from ruamel.yaml import YAML # Local imports from api.constants import IMG_BASE64_PREFIX from api.db import FileType +from common.file_utils import get_project_base_directory PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") @@ -51,21 +52,6 @@ if LOCK_KEY_pdfplumber not in sys.modules: sys.modules[LOCK_KEY_pdfplumber] = threading.Lock() -def get_project_base_directory(*args): - global PROJECT_BASE - if PROJECT_BASE is None: - PROJECT_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - os.pardir, - ) - ) - - if args: - return os.path.join(PROJECT_BASE, *args) - return PROJECT_BASE - @cached(cache=LRUCache(maxsize=10)) def load_json_conf(conf_path): if os.path.isabs(conf_path): diff --git a/api/utils/log_utils.py b/api/utils/log_utils.py index 0a4840e79..0348f9e09 100644 --- a/api/utils/log_utils.py +++ b/api/utils/log_utils.py @@ -17,19 +17,10 @@ import os import os.path import logging from logging.handlers import RotatingFileHandler +from common.file_utils import get_project_base_directory initialized_root_logger = False -def get_project_base_directory(): - PROJECT_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - os.pardir, - ) - ) - return PROJECT_BASE - def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): global initialized_root_logger if initialized_root_logger: diff --git a/common/file_utils.py b/common/file_utils.py new file mode 100644 index 000000000..6b07bf646 --- /dev/null +++ b/common/file_utils.py @@ -0,0 +1,33 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") + +def get_project_base_directory(*args): + global PROJECT_BASE + if PROJECT_BASE is None: + PROJECT_BASE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + os.pardir, + ) + ) + + if args: + return os.path.join(PROJECT_BASE, *args) + return PROJECT_BASE diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index c9f2f34ca..dba55bb13 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -34,7 +34,7 @@ from huggingface_hub import snapshot_download from PIL import Image from pypdf import PdfReader as pdf2_read -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from api.utils.common import pip_install_torch from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 9cd0b5a5d..8068cbc07 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -25,7 +25,7 @@ import cv2 import numpy as np from huggingface_hub import snapshot_download -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from deepdoc.vision import Recognizer from deepdoc.vision.operators import nms diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index 1a99e6ce0..4c4831d7c 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -21,7 +21,7 @@ import os from huggingface_hub import snapshot_download -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from api.utils.common import pip_install_torch from rag.settings import PARALLEL_DEVICES from .operators import * # noqa: F403 diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index 65995a579..af259a384 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -22,7 +22,7 @@ import cv2 from functools import cmp_to_key -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from .operators import * # noqa: F403 from .operators import preprocess from . import operators diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index 34b988e9a..cf1c79db1 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -21,7 +21,7 @@ from collections import Counter import numpy as np from huggingface_hub import snapshot_download -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from rag.nlp import rag_tokenizer from .recognizer import Recognizer diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index c3394971e..3c4b97833 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -25,7 +25,7 @@ import sys from hanziconv import HanziConv from nltk import word_tokenize from nltk.stem import PorterStemmer, WordNetLemmatizer -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory class RagTokenizer: diff --git a/rag/nlp/synonym.py b/rag/nlp/synonym.py index b28560ce1..0956ee8e8 100644 --- a/rag/nlp/synonym.py +++ b/rag/nlp/synonym.py @@ -20,7 +20,7 @@ import os import time import re from nltk.corpus import wordnet -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory class Dealer: diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 33ee62660..392117c18 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -21,7 +21,7 @@ import re import os import numpy as np from rag.nlp import rag_tokenizer -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory class Dealer: diff --git a/rag/settings.py b/rag/settings.py index 6c2017dc1..6167c08ae 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -16,7 +16,7 @@ import os import logging from api.utils.configs import get_base_config, decrypt_database_config -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from api.utils.common import pip_install_torch # Server diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 7bc577c63..8a0d464ba 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -28,7 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.utils.api_utils import timeout from api.utils.base64_image import image2id -from api.utils.log_utils import init_root_logger, get_project_base_directory +from api.utils.log_utils import init_root_logger +from common.file_utils import get_project_base_directory from api.utils.configs import show_configs from graphrag.general.index import run_graphrag_for_kb from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index f3a696318..64db543b1 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -17,7 +17,7 @@ import os import tiktoken -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory tiktoken_cache_dir = get_project_base_directory() os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index e662428ba..e2ce8c9ed 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -27,7 +27,7 @@ from elastic_transport import ConnectionTimeout from rag import settings from rag.settings import TAG_FLD, PAGERANK_FLD from common.decorator import singleton -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from api.utils.common import convert_bytes from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ FusionExpr diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 1d6dbd091..9a92c8e86 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -29,7 +29,7 @@ from rag import settings from rag.settings import PAGERANK_FLD, TAG_FLD from common.decorator import singleton import pandas as pd -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from rag.nlp import is_english from rag.utils.doc_store_conn import ( diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index 3c2cf376b..5c51be52f 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -27,7 +27,7 @@ from opensearchpy import ConnectionTimeout from rag import settings from rag.settings import TAG_FLD, PAGERANK_FLD from common.decorator import singleton -from api.utils.file_utils import get_project_base_directory +from common.file_utils import get_project_base_directory from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ FusionExpr from rag.nlp import is_english, rag_tokenizer diff --git a/test/unit_test/common/test_file_utils.py b/test/unit_test/common/test_file_utils.py new file mode 100644 index 000000000..616312cde --- /dev/null +++ b/test/unit_test/common/test_file_utils.py @@ -0,0 +1,123 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import pytest +from unittest.mock import patch +from common import file_utils +from common.file_utils import get_project_base_directory + + +class TestGetProjectBaseDirectory: + """Test cases for get_project_base_directory function""" + + def test_returns_project_base_when_no_args(self): + """Test that function returns project base directory when no arguments provided""" + result = get_project_base_directory() + + assert result is not None + assert isinstance(result, str) + assert os.path.isabs(result) # Should return absolute path + + def test_returns_path_with_single_argument(self): + """Test that function joins project base with single additional path component""" + result = get_project_base_directory("subfolder") + + assert result is not None + assert "subfolder" in result + assert result.endswith("subfolder") + + def test_returns_path_with_multiple_arguments(self): + """Test that function joins project base with multiple path components""" + result = get_project_base_directory("folder1", "folder2", "file.txt") + + assert result is not None + assert "folder1" in result + assert "folder2" in result + assert "file.txt" in result + assert os.path.basename(result) == "file.txt" + + def test_uses_environment_variable_when_available(self): + """Test that function uses RAG_PROJECT_BASE environment variable when set""" + test_path = "/custom/project/path" + + file_utils.PROJECT_BASE = test_path + + result = get_project_base_directory() + assert result == test_path + + def test_calculates_default_path_when_no_env_vars(self): + """Test that function calculates default path when no environment variables are set""" + with patch.dict(os.environ, {}, clear=True): # Clear all environment variables + # Reset the global variable to force re-initialization + + result = get_project_base_directory() + + # Should return a valid absolute path + assert result is not None + assert os.path.isabs(result) + assert os.path.basename(result) != "" # Should not be root directory + + def test_caches_project_base_value(self): + """Test that PROJECT_BASE is cached after first calculation""" + # Reset the global variable + + # First call should calculate the value + first_result = get_project_base_directory() + + # Store the current value + cached_value = file_utils.PROJECT_BASE + + # Second call should use cached value + second_result = get_project_base_directory() + + assert first_result == second_result + assert file_utils.PROJECT_BASE == cached_value + + def test_path_components_joined_correctly(self): + """Test that path components are properly joined with the base directory""" + base_path = get_project_base_directory() + expected_path = os.path.join(base_path, "data", "files", "document.txt") + + result = get_project_base_directory("data", "files", "document.txt") + + assert result == expected_path + + def test_handles_empty_string_arguments(self): + """Test that function handles empty string arguments correctly""" + result = get_project_base_directory("") + + # Should still return a valid path (base directory) + assert result is not None + assert os.path.isabs(result) + + +# Parameterized tests for different path combinations +@pytest.mark.parametrize("path_args,expected_suffix", [ + ((), ""), # No additional arguments + (("src",), "src"), + (("data", "models"), os.path.join("data", "models")), + (("config", "app", "settings.json"), os.path.join("config", "app", "settings.json")), +]) +def test_various_path_combinations(path_args, expected_suffix): + """Test various combinations of path arguments""" + base_path = get_project_base_directory() + result = get_project_base_directory(*path_args) + + if expected_suffix: + assert result.endswith(expected_suffix) + else: + assert result == base_path