From b0b866c8fdfffe84aea65cc0fdd7cd9c29920ef3 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Thu, 25 Sep 2025 18:04:49 +0800 Subject: [PATCH] Refactor: move some functions out of api/utils/__init__.py (#10216) ### What problem does this PR solve? Move the configuration/serialization, JSON, and string/bytes helper functions out of api/utils/__init__.py into the new api/utils/configs.py, api/utils/json.py, and api/utils/common.py modules, and update the importing modules accordingly. ### Type of change - [x] Refactoring --------- Signed-off-by: jinhai Signed-off-by: Jin Hai --- api/apps/__init__.py | 3 +- api/db/db_models.py | 14 ++- api/ragflow_server.py | 2 +- api/settings.py | 2 +- api/utils/__init__.py | 255 +------------------------------------ api/utils/api_utils.py | 3 +- api/utils/common.py | 23 ++++ api/utils/configs.py | 179 ++++++++++++++++++++++++++ api/utils/json.py | 78 ++++++++++++ deepdoc/vision/ocr.py | 4 +- rag/settings.py | 2 +- rag/utils/opendal_conn.py | 2 +- 12 files changed, 302 insertions(+), 265 deletions(-) create mode 100644 api/utils/common.py create mode 100644 api/utils/configs.py create mode 100644 api/utils/json.py diff --git a/api/apps/__init__.py b/api/apps/__init__.py index fba5d20b2..db27dd509 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from api.db import StatusEnum from api.db.db_models import close_connection from api.db.services import UserService -from api.utils import CustomJSONEncoder, commands +from api.utils.json import CustomJSONEncoder +from api.utils import commands from flask_mail import Mail from flask_session import Session diff --git a/api/db/db_models.py b/api/db/db_models.py index c541587c3..1ff6dfc96 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -32,6 +32,8 @@ from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from api import settings, utils from api.db import ParserType, SerializedType +from api.utils.json import json_dumps, json_loads +from api.utils.configs import deserialize_b64, serialize_b64 def singleton(cls, *args, **kw): @@ -70,12 +72,12 @@ class JSONField(LongTextField): def
db_value(self, value): if value is None: value = self.default_value - return utils.json_dumps(value) + return json_dumps(value) def python_value(self, value): if not value: return self.default_value - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) class ListField(JSONField): @@ -91,21 +93,21 @@ class SerializedField(LongTextField): def db_value(self, value): if self._serialized_type == SerializedType.PICKLE: - return utils.serialize_b64(value, to_str=True) + return serialize_b64(value, to_str=True) elif self._serialized_type == SerializedType.JSON: if value is None: return None - return utils.json_dumps(value, with_type=True) + return json_dumps(value, with_type=True) else: raise ValueError(f"the serialized type {self._serialized_type} is not supported") def python_value(self, value): if self._serialized_type == SerializedType.PICKLE: - return utils.deserialize_b64(value) + return deserialize_b64(value) elif self._serialized_type == SerializedType.JSON: if value is None: return {} - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) else: raise ValueError(f"the serialized type {self._serialized_type} is not supported") diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 0dbeb771b..fb49f3d8b 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -41,7 +41,7 @@ from api import utils from api.db.db_models import init_database_tables as init_web_db from api.db.init_data import init_web_data from api.versions import get_ragflow_version -from api.utils import show_configs +from api.utils.configs import show_configs from rag.settings import print_rag_settings from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from 
rag.utils.redis_conn import RedisDistributedLock diff --git a/api/settings.py b/api/settings.py index 3148633e6..e6763d8a2 100644 --- a/api/settings.py +++ b/api/settings.py @@ -24,7 +24,7 @@ import rag.utils.es_conn import rag.utils.infinity_conn import rag.utils.opensearch_conn from api.constants import RAG_FLOW_SERVICE_NAME -from api.utils import decrypt_database_config, get_base_config +from api.utils.configs import decrypt_database_config, get_base_config from api.utils.file_utils import get_project_base_directory from rag.nlp import search diff --git a/api/utils/__init__.py b/api/utils/__init__.py index 22161b52f..e0f8a5655 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -16,182 +16,15 @@ import base64 import datetime import hashlib -import io -import json import os -import pickle import socket import time import uuid import requests -import logging -import copy -from enum import Enum, IntEnum + import importlib -from filelock import FileLock -from api.constants import SERVICE_CONF -from . 
import file_utils - - -def conf_realpath(conf_name): - conf_path = f"conf/{conf_name}" - return os.path.join(file_utils.get_project_base_directory(), conf_path) - - -def read_config(conf_name=SERVICE_CONF): - local_config = {} - local_path = conf_realpath(f'local.{conf_name}') - - # load local config file - if os.path.exists(local_path): - local_config = file_utils.load_yaml_conf(local_path) - if not isinstance(local_config, dict): - raise ValueError(f'Invalid config file: "{local_path}".') - - global_config_path = conf_realpath(conf_name) - global_config = file_utils.load_yaml_conf(global_config_path) - - if not isinstance(global_config, dict): - raise ValueError(f'Invalid config file: "{global_config_path}".') - - global_config.update(local_config) - return global_config - - -CONFIGS = read_config() - - -def show_configs(): - msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:" - for k, v in CONFIGS.items(): - if isinstance(v, dict): - if "password" in v: - v = copy.deepcopy(v) - v["password"] = "*" * 8 - if "access_key" in v: - v = copy.deepcopy(v) - v["access_key"] = "*" * 8 - if "secret_key" in v: - v = copy.deepcopy(v) - v["secret_key"] = "*" * 8 - if "secret" in v: - v = copy.deepcopy(v) - v["secret"] = "*" * 8 - if "sas_token" in v: - v = copy.deepcopy(v) - v["sas_token"] = "*" * 8 - if "oauth" in k: - v = copy.deepcopy(v) - for key, val in v.items(): - if "client_secret" in val: - val["client_secret"] = "*" * 8 - if "authentication" in k: - v = copy.deepcopy(v) - for key, val in v.items(): - if "http_secret_key" in val: - val["http_secret_key"] = "*" * 8 - msg += f"\n\t{k}: {v}" - logging.info(msg) - - -def get_base_config(key, default=None): - if key is None: - return None - if default is None: - default = os.environ.get(key.upper()) - return CONFIGS.get(key, default) - - -use_deserialize_safe_module = get_base_config( - 'use_deserialize_safe_module', False) - - -class BaseType: - def to_dict(self): - return dict([(k.lstrip("_"), v) for k, v in 
self.__dict__.items()]) - - def to_dict_with_type(self): - def _dict(obj): - module = None - if issubclass(obj.__class__, BaseType): - data = {} - for attr, v in obj.__dict__.items(): - k = attr.lstrip("_") - data[k] = _dict(v) - module = obj.__module__ - elif isinstance(obj, (list, tuple)): - data = [] - for i, vv in enumerate(obj): - data.append(_dict(vv)) - elif isinstance(obj, dict): - data = {} - for _k, vv in obj.items(): - data[_k] = _dict(vv) - else: - data = obj - return {"type": obj.__class__.__name__, - "data": data, "module": module} - - return _dict(self) - - -class CustomJSONEncoder(json.JSONEncoder): - def __init__(self, **kwargs): - self._with_type = kwargs.pop("with_type", False) - super().__init__(**kwargs) - - def default(self, obj): - if isinstance(obj, datetime.datetime): - return obj.strftime('%Y-%m-%d %H:%M:%S') - elif isinstance(obj, datetime.date): - return obj.strftime('%Y-%m-%d') - elif isinstance(obj, datetime.timedelta): - return str(obj) - elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): - return obj.value - elif isinstance(obj, set): - return list(obj) - elif issubclass(type(obj), BaseType): - if not self._with_type: - return obj.to_dict() - else: - return obj.to_dict_with_type() - elif isinstance(obj, type): - return obj.__name__ - else: - return json.JSONEncoder.default(self, obj) - - -def rag_uuid(): - return uuid.uuid1().hex - - -def string_to_bytes(string): - return string if isinstance( - string, bytes) else string.encode(encoding="utf-8") - - -def bytes_to_string(byte): - return byte.decode(encoding="utf-8") - - -def json_dumps(src, byte=False, indent=None, with_type=False): - dest = json.dumps( - src, - indent=indent, - cls=CustomJSONEncoder, - with_type=with_type) - if byte: - dest = string_to_bytes(dest) - return dest - - -def json_loads(src, object_hook=None, object_pairs_hook=None): - if isinstance(src, bytes): - src = bytes_to_string(src) - return json.loads(src, object_hook=object_hook, - 
object_pairs_hook=object_pairs_hook) +from .common import string_to_bytes def current_timestamp(): @@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"): return time_stamp -def serialize_b64(src, to_str=False): - dest = base64.b64encode(pickle.dumps(src)) - if not to_str: - return dest - else: - return bytes_to_string(dest) - - -def deserialize_b64(src): - src = base64.b64decode( - string_to_bytes(src) if isinstance( - src, str) else src) - if use_deserialize_safe_module: - return restricted_loads(src) - return pickle.loads(src) - - -safe_module = { - 'numpy', - 'rag_flow' -} - - -class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): - import importlib - if module.split('.')[0] in safe_module: - _module = importlib.import_module(module) - return getattr(_module, name) - # Forbid everything else. - raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) - - -def restricted_loads(src): - """Helper function analogous to pickle.loads().""" - return RestrictedUnpickler(io.BytesIO(src)).load() - - def get_lan_ip(): if os.name != "nt": import fcntl @@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict): return in_dict -def decrypt_database_password(password): - encrypt_password = get_base_config("encrypt_password", False) - encrypt_module = get_base_config("encrypt_module", False) - private_key = get_base_config("private_key", None) - - if not password or not encrypt_password: - return password - - if not private_key: - raise ValueError("No private key") - - module_fun = encrypt_module.split("#") - pwdecrypt_fun = getattr( - importlib.import_module( - module_fun[0]), - module_fun[1]) - - return pwdecrypt_fun(private_key, password) - - -def decrypt_database_config( - database=None, passwd_key="password", name="database"): - if not database: - database = get_base_config(name, {}) - - database[passwd_key] = decrypt_database_password(database[passwd_key]) - return database - - -def 
update_config(key, value, conf_name=SERVICE_CONF): - conf_path = conf_realpath(conf_name=conf_name) - if not os.path.isabs(conf_path): - conf_path = os.path.join( - file_utils.get_project_base_directory(), conf_path) - - with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): - config = file_utils.load_yaml_conf(conf_path=conf_path) or {} - config[key] = value - file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) - - def get_uuid(): return uuid.uuid1().hex @@ -375,5 +128,5 @@ def delta_seconds(date_string: str): return (datetime.datetime.now() - dt).total_seconds() -def hash_str2int(line:str, mod: int=10 ** 8) -> int: - return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod \ No newline at end of file +def hash_str2int(line: str, mod: int = 10 ** 8) -> int: + return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index f8f396767..1aaaee997 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -54,7 +54,8 @@ from api.db.db_models import APIToken from api.db.services import UserService from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService -from api.utils import CustomJSONEncoder, get_uuid, json_dumps +from api.utils.json import CustomJSONEncoder, json_dumps +from api.utils import get_uuid from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) diff --git a/api/utils/common.py b/api/utils/common.py new file mode 100644 index 000000000..ce7428507 --- /dev/null +++ b/api/utils/common.py @@ -0,0 +1,23 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +def string_to_bytes(string): + return string if isinstance( + string, bytes) else string.encode(encoding="utf-8") + + +def bytes_to_string(byte): + return byte.decode(encoding="utf-8") diff --git a/api/utils/configs.py b/api/utils/configs.py new file mode 100644 index 000000000..48e492246 --- /dev/null +++ b/api/utils/configs.py @@ -0,0 +1,179 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import io +import copy +import logging +import base64 +import pickle +import importlib + +from api.utils import file_utils +from filelock import FileLock +from api.utils.common import bytes_to_string, string_to_bytes +from api.constants import SERVICE_CONF + + +def conf_realpath(conf_name): + conf_path = f"conf/{conf_name}" + return os.path.join(file_utils.get_project_base_directory(), conf_path) + + +def read_config(conf_name=SERVICE_CONF): + local_config = {} + local_path = conf_realpath(f'local.{conf_name}') + + # load local config file + if os.path.exists(local_path): + local_config = file_utils.load_yaml_conf(local_path) + if not isinstance(local_config, dict): + raise ValueError(f'Invalid config file: "{local_path}".') + + global_config_path = conf_realpath(conf_name) + global_config = file_utils.load_yaml_conf(global_config_path) + + if not isinstance(global_config, dict): + raise ValueError(f'Invalid config file: "{global_config_path}".') + + global_config.update(local_config) + return global_config + + +CONFIGS = read_config() + + +def show_configs(): + msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:" + for k, v in CONFIGS.items(): + if isinstance(v, dict): + if "password" in v: + v = copy.deepcopy(v) + v["password"] = "*" * 8 + if "access_key" in v: + v = copy.deepcopy(v) + v["access_key"] = "*" * 8 + if "secret_key" in v: + v = copy.deepcopy(v) + v["secret_key"] = "*" * 8 + if "secret" in v: + v = copy.deepcopy(v) + v["secret"] = "*" * 8 + if "sas_token" in v: + v = copy.deepcopy(v) + v["sas_token"] = "*" * 8 + if "oauth" in k: + v = copy.deepcopy(v) + for key, val in v.items(): + if "client_secret" in val: + val["client_secret"] = "*" * 8 + if "authentication" in k: + v = copy.deepcopy(v) + for key, val in v.items(): + if "http_secret_key" in val: + val["http_secret_key"] = "*" * 8 + msg += f"\n\t{k}: {v}" + logging.info(msg) + + +def get_base_config(key, default=None): + if key is None: + return None + if default is None: 
+ default = os.environ.get(key.upper()) + return CONFIGS.get(key, default) + + +def decrypt_database_password(password): + encrypt_password = get_base_config("encrypt_password", False) + encrypt_module = get_base_config("encrypt_module", False) + private_key = get_base_config("private_key", None) + + if not password or not encrypt_password: + return password + + if not private_key: + raise ValueError("No private key") + + module_fun = encrypt_module.split("#") + pwdecrypt_fun = getattr( + importlib.import_module( + module_fun[0]), + module_fun[1]) + + return pwdecrypt_fun(private_key, password) + + +def decrypt_database_config( + database=None, passwd_key="password", name="database"): + if not database: + database = get_base_config(name, {}) + + database[passwd_key] = decrypt_database_password(database[passwd_key]) + return database + + +def update_config(key, value, conf_name=SERVICE_CONF): + conf_path = conf_realpath(conf_name=conf_name) + if not os.path.isabs(conf_path): + conf_path = os.path.join( + file_utils.get_project_base_directory(), conf_path) + + with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): + config = file_utils.load_yaml_conf(conf_path=conf_path) or {} + config[key] = value + file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) + + +safe_module = { + 'numpy', + 'rag_flow' +} + + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + import importlib + if module.split('.')[0] in safe_module: + _module = importlib.import_module(module) + return getattr(_module, name) + # Forbid everything else. 
+ raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) + + +def restricted_loads(src): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(src)).load() + + +def serialize_b64(src, to_str=False): + dest = base64.b64encode(pickle.dumps(src)) + if not to_str: + return dest + else: + return bytes_to_string(dest) + + +def deserialize_b64(src): + src = base64.b64decode( + string_to_bytes(src) if isinstance( + src, str) else src) + use_deserialize_safe_module = get_base_config( + 'use_deserialize_safe_module', False) + if use_deserialize_safe_module: + return restricted_loads(src) + return pickle.loads(src) diff --git a/api/utils/json.py b/api/utils/json.py new file mode 100644 index 000000000..b21addd4f --- /dev/null +++ b/api/utils/json.py @@ -0,0 +1,78 @@ +import datetime +import json +from enum import Enum, IntEnum +from api.utils.common import string_to_bytes, bytes_to_string + + +class BaseType: + def to_dict(self): + return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()]) + + def to_dict_with_type(self): + def _dict(obj): + module = None + if issubclass(obj.__class__, BaseType): + data = {} + for attr, v in obj.__dict__.items(): + k = attr.lstrip("_") + data[k] = _dict(v) + module = obj.__module__ + elif isinstance(obj, (list, tuple)): + data = [] + for i, vv in enumerate(obj): + data.append(_dict(vv)) + elif isinstance(obj, dict): + data = {} + for _k, vv in obj.items(): + data[_k] = _dict(vv) + else: + data = obj + return {"type": obj.__class__.__name__, + "data": data, "module": module} + + return _dict(self) + + +class CustomJSONEncoder(json.JSONEncoder): + def __init__(self, **kwargs): + self._with_type = kwargs.pop("with_type", False) + super().__init__(**kwargs) + + def default(self, obj): + if isinstance(obj, datetime.datetime): + return obj.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(obj, datetime.date): + return obj.strftime('%Y-%m-%d') + elif isinstance(obj, 
datetime.timedelta): + return str(obj) + elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): + return obj.value + elif isinstance(obj, set): + return list(obj) + elif issubclass(type(obj), BaseType): + if not self._with_type: + return obj.to_dict() + else: + return obj.to_dict_with_type() + elif isinstance(obj, type): + return obj.__name__ + else: + return json.JSONEncoder.default(self, obj) + + +def json_dumps(src, byte=False, indent=None, with_type=False): + dest = json.dumps( + src, + indent=indent, + cls=CustomJSONEncoder, + with_type=with_type) + if byte: + dest = string_to_bytes(dest) + return dest + + +def json_loads(src, object_hook=None, object_pairs_hook=None): + if isinstance(src, bytes): + src = bytes_to_string(src) + return json.loads(src, object_hook=object_hook, + object_pairs_hook=object_pairs_hook) diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index d9f472aa1..d91de2ab8 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -350,7 +350,7 @@ class TextRecognizer: def close(self): # close session and release manually - logging.info('Close TextRecognizer.') + logging.info('Close text recognizer.') if hasattr(self, "predictor"): del self.predictor gc.collect() @@ -490,7 +490,7 @@ class TextDetector: return dt_boxes def close(self): - logging.info("Close TextDetector.") + logging.info("Close text detector.") if hasattr(self, "predictor"): del self.predictor gc.collect() diff --git a/rag/settings.py b/rag/settings.py index 70d1b6234..c78728783 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -15,7 +15,7 @@ # import os import logging -from api.utils import get_base_config, decrypt_database_config +from api.utils.configs import get_base_config, decrypt_database_config from api.utils.file_utils import get_project_base_directory # Server diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index c4fe92563..7642b33d4 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -3,7 +3,7 
@@ import logging import pymysql from urllib.parse import quote_plus -from api.utils import get_base_config +from api.utils.configs import get_base_config from rag.utils import singleton