diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py index c14c7d4cc..3524e6faa 100644 --- a/admin/server/admin_server.py +++ b/admin/server/admin_server.py @@ -25,6 +25,7 @@ from flask import Flask from routes import admin_bp from api.utils.log_utils import init_root_logger from common.contants import SERVICE_CONF +from common.config_utils import show_configs from api import settings from config import load_configurations, SERVICE_CONFIGS from auth import init_default_admin, setup_auth @@ -51,6 +52,7 @@ if __name__ == '__main__': os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024) ) Session(app) + show_configs() login_manager = LoginManager() login_manager.init_app(app) settings.init_settings() diff --git a/admin/server/config.py b/admin/server/config.py index a14d95b1a..5d47f0d66 100644 --- a/admin/server/config.py +++ b/admin/server/config.py @@ -21,7 +21,7 @@ from enum import Enum from pydantic import BaseModel from typing import Any -from api.utils.configs import read_config +from common.config_utils import read_config from urllib.parse import urlparse diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 8ce391649..2cdba51e6 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -41,7 +41,7 @@ from common.file_utils import get_project_base_directory from api.db.db_models import init_database_tables as init_web_db from api.db.init_data import init_web_data from api.versions import get_ragflow_version -from api.utils.configs import show_configs +from common.config_utils import show_configs from rag.settings import print_rag_settings from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.redis_conn import RedisDistributedLock diff --git a/api/settings.py b/api/settings.py index 592753eb2..10d4e3b15 100644 --- a/api/settings.py +++ b/api/settings.py @@ -24,7 +24,7 @@ import rag.utils.es_conn import rag.utils.infinity_conn import rag.utils.opensearch_conn from api.constants import RAG_FLOW_SERVICE_NAME -from api.utils.configs import decrypt_database_config, get_base_config +from common.config_utils import decrypt_database_config, get_base_config from common.file_utils import get_project_base_directory from rag.nlp import search diff --git a/api/utils/configs.py b/api/utils/configs.py index c8efebf11..91baa28e3 100644 --- a/api/utils/configs.py +++ b/api/utils/configs.py @@ -14,129 +14,11 @@ # limitations under the License. # -import os import io -import copy -import logging import base64 import pickle -import importlib - -from api.utils import file_utils -from common.file_utils import get_project_base_directory -from filelock import FileLock from api.utils.common import bytes_to_string, string_to_bytes -from common.contants import SERVICE_CONF - - -def conf_realpath(conf_name): - conf_path = f"conf/{conf_name}" - return os.path.join(get_project_base_directory(), conf_path) - - -def read_config(conf_name=SERVICE_CONF): - local_config = {} - local_path = conf_realpath(f'local.{conf_name}') - - # load local config file - if os.path.exists(local_path): - local_config = file_utils.load_yaml_conf(local_path) - if not isinstance(local_config, dict): - raise ValueError(f'Invalid config file: "{local_path}".') - - global_config_path = conf_realpath(conf_name) - global_config = file_utils.load_yaml_conf(global_config_path) - - if not isinstance(global_config, dict): - raise ValueError(f'Invalid config file: "{global_config_path}".') - - global_config.update(local_config) - return global_config - - -CONFIGS = read_config() - - -def show_configs(): - msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:" - for k, v in CONFIGS.items(): - if isinstance(v, dict): - if "password" in v: - v = copy.deepcopy(v) - v["password"] = "*" * 8 - if "access_key" in v: - v = copy.deepcopy(v) - v["access_key"] = "*" * 8 - if "secret_key" in v: - v = copy.deepcopy(v) - v["secret_key"] = "*" * 8 - if "secret" in v: - v = copy.deepcopy(v) - v["secret"] = "*" * 8 - if "sas_token" in v: - v = copy.deepcopy(v) - v["sas_token"] = "*" * 8 - if "oauth" in k: - v = copy.deepcopy(v) - for key, val in v.items(): - if "client_secret" in val: - val["client_secret"] = "*" * 8 - if "authentication" in k: - v = copy.deepcopy(v) - for key, val in v.items(): - if "http_secret_key" in val: - val["http_secret_key"] = "*" * 8 - msg += f"\n\t{k}: {v}" - logging.info(msg) - - -def get_base_config(key, default=None): - if key is None: - return None - if default is None: - default = os.environ.get(key.upper()) - return CONFIGS.get(key, default) - - -def decrypt_database_password(password): - encrypt_password = get_base_config("encrypt_password", False) - encrypt_module = get_base_config("encrypt_module", False) - private_key = get_base_config("private_key", None) - - if not password or not encrypt_password: - return password - - if not private_key: - raise ValueError("No private key") - - module_fun = encrypt_module.split("#") - pwdecrypt_fun = getattr( - importlib.import_module( - module_fun[0]), - module_fun[1]) - - return pwdecrypt_fun(private_key, password) - - -def decrypt_database_config( - database=None, passwd_key="password", name="database"): - if not database: - database = get_base_config(name, {}) - - database[passwd_key] = decrypt_database_password(database[passwd_key]) - return database - - -def update_config(key, value, conf_name=SERVICE_CONF): - conf_path = conf_realpath(conf_name=conf_name) - if not os.path.isabs(conf_path): - conf_path = os.path.join(get_project_base_directory(), conf_path) - - with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): - config = file_utils.load_yaml_conf(conf_path=conf_path) or {} - config[key] = value - file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) - +from common.config_utils import get_base_config safe_module = { 'numpy', diff --git a/common/config_utils.py b/common/config_utils.py new file mode 100644 index 000000000..f368b9372 --- /dev/null +++ b/common/config_utils.py @@ -0,0 +1,155 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import copy +import logging +import importlib +from filelock import FileLock + +from common.file_utils import get_project_base_directory +from common.contants import SERVICE_CONF +from ruamel.yaml import YAML + + +def load_yaml_conf(conf_path): + if not os.path.isabs(conf_path): + conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(conf_path) as f: + yaml = YAML(typ="safe", pure=True) + return yaml.load(f) + except Exception as e: + raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e) + + +def rewrite_yaml_conf(conf_path, config): + if not os.path.isabs(conf_path): + conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(conf_path, "w") as f: + yaml = YAML(typ="safe") + yaml.dump(config, f) + except Exception as e: + raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e) + + +def conf_realpath(conf_name): + conf_path = f"conf/{conf_name}" + return os.path.join(get_project_base_directory(), conf_path) + + +def read_config(conf_name=SERVICE_CONF): + local_config = {} + local_path = conf_realpath(f'local.{conf_name}') + + # load local config file + if os.path.exists(local_path): + local_config = load_yaml_conf(local_path) + if not isinstance(local_config, dict): + raise ValueError(f'Invalid config file: "{local_path}".') + + global_config_path = conf_realpath(conf_name) + global_config = load_yaml_conf(global_config_path) + + if not isinstance(global_config, dict): + raise ValueError(f'Invalid config file: "{global_config_path}".') + + global_config.update(local_config) + return global_config + + +CONFIGS = read_config() + + +def show_configs(): + msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:" + for k, v in CONFIGS.items(): + if isinstance(v, dict): + if "password" in v: + v = copy.deepcopy(v) + v["password"] = "*" * 8 + if "access_key" in v: + v = copy.deepcopy(v) + v["access_key"] = "*" * 8 + if "secret_key" in v: + v = copy.deepcopy(v) + v["secret_key"] = "*" * 8 + if "secret" in v: + v = copy.deepcopy(v) + v["secret"] = "*" * 8 + if "sas_token" in v: + v = copy.deepcopy(v) + v["sas_token"] = "*" * 8 + if "oauth" in k: + v = copy.deepcopy(v) + for key, val in v.items(): + if "client_secret" in val: + val["client_secret"] = "*" * 8 + if "authentication" in k: + v = copy.deepcopy(v) + for key, val in v.items(): + if "http_secret_key" in val: + val["http_secret_key"] = "*" * 8 + msg += f"\n\t{k}: {v}" + logging.info(msg) + + +def get_base_config(key, default=None): + if key is None: + return None + if default is None: + default = os.environ.get(key.upper()) + return CONFIGS.get(key, default) + + +def decrypt_database_password(password): + encrypt_password = get_base_config("encrypt_password", False) + encrypt_module = get_base_config("encrypt_module", False) + private_key = get_base_config("private_key", None) + + if not password or not encrypt_password: + return password + + if not private_key: + raise ValueError("No private key") + + module_fun = encrypt_module.split("#") + pwdecrypt_fun = getattr( + importlib.import_module( + module_fun[0]), + module_fun[1]) + + return pwdecrypt_fun(private_key, password) + + +def decrypt_database_config(database=None, passwd_key="password", name="database"): + if not database: + database = get_base_config(name, {}) + + database[passwd_key] = decrypt_database_password(database[passwd_key]) + return database + + +def update_config(key, value, conf_name=SERVICE_CONF): + conf_path = conf_realpath(conf_name=conf_name) + if not os.path.isabs(conf_path): + conf_path = os.path.join(get_project_base_directory(), conf_path) + + with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): + config = load_yaml_conf(conf_path=conf_path) or {} + config[key] = value + rewrite_yaml_conf(conf_path=conf_path, config=config) diff --git a/deepdoc/parser/tcadp_parser.py b/deepdoc/parser/tcadp_parser.py index f84a0e6a4..1b7a3e362 100644 --- a/deepdoc/parser/tcadp_parser.py +++ b/deepdoc/parser/tcadp_parser.py @@ -36,7 +36,7 @@ from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException from tencentcloud.lkeap.v20240522 import lkeap_client, models -from api.utils.configs import get_base_config +from common.config_utils import get_base_config from deepdoc.parser.pdf_parser import RAGFlowPdfParser diff --git a/rag/settings.py b/rag/settings.py index 1a8ee95b5..57df2ee14 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -15,7 +15,7 @@ # import os import logging -from api.utils.configs import get_base_config, decrypt_database_config +from common.config_utils import get_base_config, decrypt_database_config from common.file_utils import get_project_base_directory from common.misc_utils import pip_install_torch diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index d786b13c1..994999a99 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -30,7 +30,7 @@ from api.utils.api_utils import timeout from common.base64_image import image2id from api.utils.log_utils import init_root_logger from common.file_utils import get_project_base_directory -from api.utils.configs import show_configs +from common.config_utils import show_configs from graphrag.general.index import run_graphrag_for_kb from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from rag.flow.pipeline import Pipeline diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index 41abdf343..54650b54b 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -3,7 +3,7 @@ import logging import pymysql from urllib.parse import quote_plus -from api.utils.configs import get_base_config +from common.config_utils import get_base_config from common.decorator import singleton