Refactor: move some functions out of api/utils/__init__.py (#10216)

### What problem does this PR solve?

Refactor import modules.

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-09-25 18:04:49 +08:00
committed by GitHub
parent 3a831d0c28
commit b0b866c8fd
12 changed files with 302 additions and 265 deletions

View File

@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api.db import StatusEnum from api.db import StatusEnum
from api.db.db_models import close_connection from api.db.db_models import close_connection
from api.db.services import UserService from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands from api.utils.json import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail from flask_mail import Mail
from flask_session import Session from flask_session import Session

View File

@ -32,6 +32,8 @@ from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api import settings, utils from api import settings, utils
from api.db import ParserType, SerializedType from api.db import ParserType, SerializedType
from api.utils.json import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
def singleton(cls, *args, **kw): def singleton(cls, *args, **kw):
@ -70,12 +72,12 @@ class JSONField(LongTextField):
def db_value(self, value): def db_value(self, value):
if value is None: if value is None:
value = self.default_value value = self.default_value
return utils.json_dumps(value) return json_dumps(value)
def python_value(self, value): def python_value(self, value):
if not value: if not value:
return self.default_value return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField): class ListField(JSONField):
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
def db_value(self, value): def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE: if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True) return serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON: elif self._serialized_type == SerializedType.JSON:
if value is None: if value is None:
return None return None
return utils.json_dumps(value, with_type=True) return json_dumps(value, with_type=True)
else: else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported") raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value): def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE: if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value) return deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON: elif self._serialized_type == SerializedType.JSON:
if value is None: if value is None:
return {} return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else: else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported") raise ValueError(f"the serialized type {self._serialized_type} is not supported")

View File

@ -41,7 +41,7 @@ from api import utils
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data from api.db.init_data import init_web_data
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.utils import show_configs from api.utils.configs import show_configs
from rag.settings import print_rag_settings from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import RedisDistributedLock

View File

@ -24,7 +24,7 @@ import rag.utils.es_conn
import rag.utils.infinity_conn import rag.utils.infinity_conn
import rag.utils.opensearch_conn import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config from api.utils.configs import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from rag.nlp import search from rag.nlp import search

View File

@ -16,182 +16,15 @@
import base64 import base64
import datetime import datetime
import hashlib import hashlib
import io
import json
import os import os
import pickle
import socket import socket
import time import time
import uuid import uuid
import requests import requests
import logging
import copy
from enum import Enum, IntEnum
import importlib import importlib
from filelock import FileLock
from api.constants import SERVICE_CONF
from . import file_utils from .common import string_to_bytes
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def rag_uuid():
return uuid.uuid1().hex
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)
def current_timestamp(): def current_timestamp():
@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
return time_stamp return time_stamp
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def get_lan_ip(): def get_lan_ip():
if os.name != "nt": if os.name != "nt":
import fcntl import fcntl
@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict):
return in_dict return in_dict
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid(): def get_uuid():
return uuid.uuid1().hex return uuid.uuid1().hex
@ -375,5 +128,5 @@ def delta_seconds(date_string: str):
return (datetime.datetime.now() - dt).total_seconds() return (datetime.datetime.now() - dt).total_seconds()
def hash_str2int(line:str, mod: int=10 ** 8) -> int: def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod

View File

@ -54,7 +54,8 @@ from api.db.db_models import APIToken
from api.db.services import UserService from api.db.services import UserService
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps from api.utils.json import CustomJSONEncoder, json_dumps
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)

23
api/utils/common.py Normal file
View File

@ -0,0 +1,23 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")

179
api/utils/configs.py Normal file
View File

@ -0,0 +1,179 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import io
import copy
import logging
import base64
import pickle
import importlib
from api.utils import file_utils
from filelock import FileLock
from api.utils.common import bytes_to_string, string_to_bytes
from api.constants import SERVICE_CONF
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)

78
api/utils/json.py Normal file
View File

@ -0,0 +1,78 @@
import datetime
import json
from enum import Enum, IntEnum
from api.utils.common import string_to_bytes, bytes_to_string
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)

View File

@ -350,7 +350,7 @@ class TextRecognizer:
def close(self): def close(self):
# close session and release manually # close session and release manually
logging.info('Close TextRecognizer.') logging.info('Close text recognizer.')
if hasattr(self, "predictor"): if hasattr(self, "predictor"):
del self.predictor del self.predictor
gc.collect() gc.collect()
@ -490,7 +490,7 @@ class TextDetector:
return dt_boxes return dt_boxes
def close(self): def close(self):
logging.info("Close TextDetector.") logging.info("Close text detector.")
if hasattr(self, "predictor"): if hasattr(self, "predictor"):
del self.predictor del self.predictor
gc.collect() gc.collect()

View File

@ -15,7 +15,7 @@
# #
import os import os
import logging import logging
from api.utils import get_base_config, decrypt_database_config from api.utils.configs import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
# Server # Server

View File

@ -3,7 +3,7 @@ import logging
import pymysql import pymysql
from urllib.parse import quote_plus from urllib.parse import quote_plus
from api.utils import get_base_config from api.utils.configs import get_base_config
from rag.utils import singleton from rag.utils import singleton