Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 12:32:30 +08:00)
Refactor: move some functions out of api/utils/__init__.py (#10216)
### What problem does this PR solve?

Refactor import modules.

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
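The refactor moves helpers that used to be re-exported from `api/utils/__init__.py` into dedicated modules (`api.utils.json`, `api.utils.configs`, `api.utils.common`) and updates callers accordingly. A minimal sketch of the resulting import pattern, based only on the changes shown in this diff (commented-out lines show the old locations):

```python
# Before this PR, these helpers were imported from the package root:
# from api.utils import CustomJSONEncoder, json_dumps, json_loads
# from api.utils import get_base_config, decrypt_database_config, show_configs
# from api.utils import serialize_b64, deserialize_b64

# After this PR, they live in dedicated submodules:
from api.utils.json import CustomJSONEncoder, json_dumps, json_loads
from api.utils.configs import (
    get_base_config,
    decrypt_database_config,
    show_configs,
    serialize_b64,
    deserialize_b64,
)
from api.utils.common import string_to_bytes, bytes_to_string

# Helpers that remain in api/utils/__init__.py keep their old import path:
from api.utils import commands, get_uuid
```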
@@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from api.db import StatusEnum
 from api.db.db_models import close_connection
 from api.db.services import UserService
-from api.utils import CustomJSONEncoder, commands
+from api.utils.json import CustomJSONEncoder
+from api.utils import commands

 from flask_mail import Mail
 from flask_session import Session

@@ -32,6 +32,8 @@ from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

 from api import settings, utils
 from api.db import ParserType, SerializedType
+from api.utils.json import json_dumps, json_loads
+from api.utils.configs import deserialize_b64, serialize_b64


 def singleton(cls, *args, **kw):
@@ -70,12 +72,12 @@ class JSONField(LongTextField):
     def db_value(self, value):
         if value is None:
             value = self.default_value
-        return utils.json_dumps(value)
+        return json_dumps(value)

     def python_value(self, value):
         if not value:
             return self.default_value
-        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
+        return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)


 class ListField(JSONField):
@@ -91,21 +93,21 @@ class SerializedField(LongTextField):

     def db_value(self, value):
         if self._serialized_type == SerializedType.PICKLE:
-            return utils.serialize_b64(value, to_str=True)
+            return serialize_b64(value, to_str=True)
         elif self._serialized_type == SerializedType.JSON:
             if value is None:
                 return None
-            return utils.json_dumps(value, with_type=True)
+            return json_dumps(value, with_type=True)
         else:
             raise ValueError(f"the serialized type {self._serialized_type} is not supported")

     def python_value(self, value):
         if self._serialized_type == SerializedType.PICKLE:
-            return utils.deserialize_b64(value)
+            return deserialize_b64(value)
         elif self._serialized_type == SerializedType.JSON:
             if value is None:
                 return {}
-            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
+            return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
         else:
             raise ValueError(f"the serialized type {self._serialized_type} is not supported")


@@ -41,7 +41,7 @@ from api import utils
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.init_data import init_web_data
 from api.versions import get_ragflow_version
-from api.utils import show_configs
+from api.utils.configs import show_configs
 from rag.settings import print_rag_settings
 from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
 from rag.utils.redis_conn import RedisDistributedLock

@@ -24,7 +24,7 @@ import rag.utils.es_conn
 import rag.utils.infinity_conn
 import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
-from api.utils import decrypt_database_config, get_base_config
+from api.utils.configs import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
 from rag.nlp import search


@@ -16,182 +16,15 @@
 import base64
 import datetime
 import hashlib
-import io
-import json
 import os
-import pickle
 import socket
 import time
 import uuid
 import requests
-import logging
-import copy
-from enum import Enum, IntEnum
 import importlib
-from filelock import FileLock
-from api.constants import SERVICE_CONF

-from . import file_utils
+from .common import string_to_bytes


-def conf_realpath(conf_name):
-    conf_path = f"conf/{conf_name}"
-    return os.path.join(file_utils.get_project_base_directory(), conf_path)
-
-
-def read_config(conf_name=SERVICE_CONF):
-    local_config = {}
-    local_path = conf_realpath(f'local.{conf_name}')
-
-    # load local config file
-    if os.path.exists(local_path):
-        local_config = file_utils.load_yaml_conf(local_path)
-        if not isinstance(local_config, dict):
-            raise ValueError(f'Invalid config file: "{local_path}".')
-
-    global_config_path = conf_realpath(conf_name)
-    global_config = file_utils.load_yaml_conf(global_config_path)
-
-    if not isinstance(global_config, dict):
-        raise ValueError(f'Invalid config file: "{global_config_path}".')
-
-    global_config.update(local_config)
-    return global_config
-
-
-CONFIGS = read_config()
-
-
-def show_configs():
-    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
-    for k, v in CONFIGS.items():
-        if isinstance(v, dict):
-            if "password" in v:
-                v = copy.deepcopy(v)
-                v["password"] = "*" * 8
-            if "access_key" in v:
-                v = copy.deepcopy(v)
-                v["access_key"] = "*" * 8
-            if "secret_key" in v:
-                v = copy.deepcopy(v)
-                v["secret_key"] = "*" * 8
-            if "secret" in v:
-                v = copy.deepcopy(v)
-                v["secret"] = "*" * 8
-            if "sas_token" in v:
-                v = copy.deepcopy(v)
-                v["sas_token"] = "*" * 8
-            if "oauth" in k:
-                v = copy.deepcopy(v)
-                for key, val in v.items():
-                    if "client_secret" in val:
-                        val["client_secret"] = "*" * 8
-            if "authentication" in k:
-                v = copy.deepcopy(v)
-                for key, val in v.items():
-                    if "http_secret_key" in val:
-                        val["http_secret_key"] = "*" * 8
-        msg += f"\n\t{k}: {v}"
-    logging.info(msg)
-
-
-def get_base_config(key, default=None):
-    if key is None:
-        return None
-    if default is None:
-        default = os.environ.get(key.upper())
-    return CONFIGS.get(key, default)
-
-
-use_deserialize_safe_module = get_base_config(
-    'use_deserialize_safe_module', False)
-
-
-class BaseType:
-    def to_dict(self):
-        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
-
-    def to_dict_with_type(self):
-        def _dict(obj):
-            module = None
-            if issubclass(obj.__class__, BaseType):
-                data = {}
-                for attr, v in obj.__dict__.items():
-                    k = attr.lstrip("_")
-                    data[k] = _dict(v)
-                module = obj.__module__
-            elif isinstance(obj, (list, tuple)):
-                data = []
-                for i, vv in enumerate(obj):
-                    data.append(_dict(vv))
-            elif isinstance(obj, dict):
-                data = {}
-                for _k, vv in obj.items():
-                    data[_k] = _dict(vv)
-            else:
-                data = obj
-            return {"type": obj.__class__.__name__,
-                    "data": data, "module": module}
-
-        return _dict(self)
-
-
-class CustomJSONEncoder(json.JSONEncoder):
-    def __init__(self, **kwargs):
-        self._with_type = kwargs.pop("with_type", False)
-        super().__init__(**kwargs)
-
-    def default(self, obj):
-        if isinstance(obj, datetime.datetime):
-            return obj.strftime('%Y-%m-%d %H:%M:%S')
-        elif isinstance(obj, datetime.date):
-            return obj.strftime('%Y-%m-%d')
-        elif isinstance(obj, datetime.timedelta):
-            return str(obj)
-        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
-            return obj.value
-        elif isinstance(obj, set):
-            return list(obj)
-        elif issubclass(type(obj), BaseType):
-            if not self._with_type:
-                return obj.to_dict()
-            else:
-                return obj.to_dict_with_type()
-        elif isinstance(obj, type):
-            return obj.__name__
-        else:
-            return json.JSONEncoder.default(self, obj)
-
-
-def rag_uuid():
-    return uuid.uuid1().hex
-
-
-def string_to_bytes(string):
-    return string if isinstance(
-        string, bytes) else string.encode(encoding="utf-8")
-
-
-def bytes_to_string(byte):
-    return byte.decode(encoding="utf-8")
-
-
-def json_dumps(src, byte=False, indent=None, with_type=False):
-    dest = json.dumps(
-        src,
-        indent=indent,
-        cls=CustomJSONEncoder,
-        with_type=with_type)
-    if byte:
-        dest = string_to_bytes(dest)
-    return dest
-
-
-def json_loads(src, object_hook=None, object_pairs_hook=None):
-    if isinstance(src, bytes):
-        src = bytes_to_string(src)
-    return json.loads(src, object_hook=object_hook,
-                      object_pairs_hook=object_pairs_hook)
-
-
 def current_timestamp():
@@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
     return time_stamp


-def serialize_b64(src, to_str=False):
-    dest = base64.b64encode(pickle.dumps(src))
-    if not to_str:
-        return dest
-    else:
-        return bytes_to_string(dest)
-
-
-def deserialize_b64(src):
-    src = base64.b64decode(
-        string_to_bytes(src) if isinstance(
-            src, str) else src)
-    if use_deserialize_safe_module:
-        return restricted_loads(src)
-    return pickle.loads(src)
-
-
-safe_module = {
-    'numpy',
-    'rag_flow'
-}
-
-
-class RestrictedUnpickler(pickle.Unpickler):
-    def find_class(self, module, name):
-        import importlib
-        if module.split('.')[0] in safe_module:
-            _module = importlib.import_module(module)
-            return getattr(_module, name)
-        # Forbid everything else.
-        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
-                                     (module, name))
-
-
-def restricted_loads(src):
-    """Helper function analogous to pickle.loads()."""
-    return RestrictedUnpickler(io.BytesIO(src)).load()
-
-
 def get_lan_ip():
     if os.name != "nt":
         import fcntl
@@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict):
     return in_dict


-def decrypt_database_password(password):
-    encrypt_password = get_base_config("encrypt_password", False)
-    encrypt_module = get_base_config("encrypt_module", False)
-    private_key = get_base_config("private_key", None)
-
-    if not password or not encrypt_password:
-        return password
-
-    if not private_key:
-        raise ValueError("No private key")
-
-    module_fun = encrypt_module.split("#")
-    pwdecrypt_fun = getattr(
-        importlib.import_module(
-            module_fun[0]),
-        module_fun[1])
-
-    return pwdecrypt_fun(private_key, password)
-
-
-def decrypt_database_config(
-        database=None, passwd_key="password", name="database"):
-    if not database:
-        database = get_base_config(name, {})
-
-    database[passwd_key] = decrypt_database_password(database[passwd_key])
-    return database
-
-
-def update_config(key, value, conf_name=SERVICE_CONF):
-    conf_path = conf_realpath(conf_name=conf_name)
-    if not os.path.isabs(conf_path):
-        conf_path = os.path.join(
-            file_utils.get_project_base_directory(), conf_path)
-
-    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
-        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
-        config[key] = value
-        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
-
-
 def get_uuid():
     return uuid.uuid1().hex

@@ -375,5 +128,5 @@ def delta_seconds(date_string: str):
     return (datetime.datetime.now() - dt).total_seconds()


-def hash_str2int(line:str, mod: int=10 ** 8) -> int:
+def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
     return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod

@@ -54,7 +54,8 @@ from api.db.db_models import APIToken
 from api.db.services import UserService
 from api.db.services.llm_service import LLMService
 from api.db.services.tenant_llm_service import TenantLLMService
-from api.utils import CustomJSONEncoder, get_uuid, json_dumps
+from api.utils.json import CustomJSONEncoder, json_dumps
+from api.utils import get_uuid
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)

api/utils/common.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+def string_to_bytes(string):
+    return string if isinstance(
+        string, bytes) else string.encode(encoding="utf-8")
+
+
+def bytes_to_string(byte):
+    return byte.decode(encoding="utf-8")

api/utils/configs.py (new file, 179 lines)
@@ -0,0 +1,179 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+import os
+import io
+import copy
+import logging
+import base64
+import pickle
+import importlib
+
+from api.utils import file_utils
+from filelock import FileLock
+from api.utils.common import bytes_to_string, string_to_bytes
+from api.constants import SERVICE_CONF
+
+
+def conf_realpath(conf_name):
+    conf_path = f"conf/{conf_name}"
+    return os.path.join(file_utils.get_project_base_directory(), conf_path)
+
+
+def read_config(conf_name=SERVICE_CONF):
+    local_config = {}
+    local_path = conf_realpath(f'local.{conf_name}')
+
+    # load local config file
+    if os.path.exists(local_path):
+        local_config = file_utils.load_yaml_conf(local_path)
+        if not isinstance(local_config, dict):
+            raise ValueError(f'Invalid config file: "{local_path}".')
+
+    global_config_path = conf_realpath(conf_name)
+    global_config = file_utils.load_yaml_conf(global_config_path)
+
+    if not isinstance(global_config, dict):
+        raise ValueError(f'Invalid config file: "{global_config_path}".')
+
+    global_config.update(local_config)
+    return global_config
+
+
+CONFIGS = read_config()
+
+
+def show_configs():
+    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
+    for k, v in CONFIGS.items():
+        if isinstance(v, dict):
+            if "password" in v:
+                v = copy.deepcopy(v)
+                v["password"] = "*" * 8
+            if "access_key" in v:
+                v = copy.deepcopy(v)
+                v["access_key"] = "*" * 8
+            if "secret_key" in v:
+                v = copy.deepcopy(v)
+                v["secret_key"] = "*" * 8
+            if "secret" in v:
+                v = copy.deepcopy(v)
+                v["secret"] = "*" * 8
+            if "sas_token" in v:
+                v = copy.deepcopy(v)
+                v["sas_token"] = "*" * 8
+            if "oauth" in k:
+                v = copy.deepcopy(v)
+                for key, val in v.items():
+                    if "client_secret" in val:
+                        val["client_secret"] = "*" * 8
+            if "authentication" in k:
+                v = copy.deepcopy(v)
+                for key, val in v.items():
+                    if "http_secret_key" in val:
+                        val["http_secret_key"] = "*" * 8
+        msg += f"\n\t{k}: {v}"
+    logging.info(msg)
+
+
+def get_base_config(key, default=None):
+    if key is None:
+        return None
+    if default is None:
+        default = os.environ.get(key.upper())
+    return CONFIGS.get(key, default)
+
+
+def decrypt_database_password(password):
+    encrypt_password = get_base_config("encrypt_password", False)
+    encrypt_module = get_base_config("encrypt_module", False)
+    private_key = get_base_config("private_key", None)
+
+    if not password or not encrypt_password:
+        return password
+
+    if not private_key:
+        raise ValueError("No private key")
+
+    module_fun = encrypt_module.split("#")
+    pwdecrypt_fun = getattr(
+        importlib.import_module(
+            module_fun[0]),
+        module_fun[1])
+
+    return pwdecrypt_fun(private_key, password)
+
+
+def decrypt_database_config(
+        database=None, passwd_key="password", name="database"):
+    if not database:
+        database = get_base_config(name, {})
+
+    database[passwd_key] = decrypt_database_password(database[passwd_key])
+    return database
+
+
+def update_config(key, value, conf_name=SERVICE_CONF):
+    conf_path = conf_realpath(conf_name=conf_name)
+    if not os.path.isabs(conf_path):
+        conf_path = os.path.join(
+            file_utils.get_project_base_directory(), conf_path)
+
+    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
+        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
+        config[key] = value
+        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
+
+
+safe_module = {
+    'numpy',
+    'rag_flow'
+}
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        import importlib
+        if module.split('.')[0] in safe_module:
+            _module = importlib.import_module(module)
+            return getattr(_module, name)
+        # Forbid everything else.
+        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
+                                     (module, name))
+
+
+def restricted_loads(src):
+    """Helper function analogous to pickle.loads()."""
+    return RestrictedUnpickler(io.BytesIO(src)).load()
+
+
+def serialize_b64(src, to_str=False):
+    dest = base64.b64encode(pickle.dumps(src))
+    if not to_str:
+        return dest
+    else:
+        return bytes_to_string(dest)
+
+
+def deserialize_b64(src):
+    src = base64.b64decode(
+        string_to_bytes(src) if isinstance(
+            src, str) else src)
+    use_deserialize_safe_module = get_base_config(
+        'use_deserialize_safe_module', False)
+    if use_deserialize_safe_module:
+        return restricted_loads(src)
+    return pickle.loads(src)

api/utils/json.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import datetime
+import json
+from enum import Enum, IntEnum
+from api.utils.common import string_to_bytes, bytes_to_string
+
+
+class BaseType:
+    def to_dict(self):
+        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
+
+    def to_dict_with_type(self):
+        def _dict(obj):
+            module = None
+            if issubclass(obj.__class__, BaseType):
+                data = {}
+                for attr, v in obj.__dict__.items():
+                    k = attr.lstrip("_")
+                    data[k] = _dict(v)
+                module = obj.__module__
+            elif isinstance(obj, (list, tuple)):
+                data = []
+                for i, vv in enumerate(obj):
+                    data.append(_dict(vv))
+            elif isinstance(obj, dict):
+                data = {}
+                for _k, vv in obj.items():
+                    data[_k] = _dict(vv)
+            else:
+                data = obj
+            return {"type": obj.__class__.__name__,
+                    "data": data, "module": module}
+
+        return _dict(self)
+
+
+class CustomJSONEncoder(json.JSONEncoder):
+    def __init__(self, **kwargs):
+        self._with_type = kwargs.pop("with_type", False)
+        super().__init__(**kwargs)
+
+    def default(self, obj):
+        if isinstance(obj, datetime.datetime):
+            return obj.strftime('%Y-%m-%d %H:%M:%S')
+        elif isinstance(obj, datetime.date):
+            return obj.strftime('%Y-%m-%d')
+        elif isinstance(obj, datetime.timedelta):
+            return str(obj)
+        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
+            return obj.value
+        elif isinstance(obj, set):
+            return list(obj)
+        elif issubclass(type(obj), BaseType):
+            if not self._with_type:
+                return obj.to_dict()
+            else:
+                return obj.to_dict_with_type()
+        elif isinstance(obj, type):
+            return obj.__name__
+        else:
+            return json.JSONEncoder.default(self, obj)
+
+
+def json_dumps(src, byte=False, indent=None, with_type=False):
+    dest = json.dumps(
+        src,
+        indent=indent,
+        cls=CustomJSONEncoder,
+        with_type=with_type)
+    if byte:
+        dest = string_to_bytes(dest)
+    return dest
+
+
+def json_loads(src, object_hook=None, object_pairs_hook=None):
+    if isinstance(src, bytes):
+        src = bytes_to_string(src)
+    return json.loads(src, object_hook=object_hook,
+                      object_pairs_hook=object_pairs_hook)

@@ -350,7 +350,7 @@ class TextRecognizer:

     def close(self):
         # close session and release manually
-        logging.info('Close TextRecognizer.')
+        logging.info('Close text recognizer.')
         if hasattr(self, "predictor"):
             del self.predictor
         gc.collect()
@@ -490,7 +490,7 @@ class TextDetector:
         return dt_boxes

     def close(self):
-        logging.info("Close TextDetector.")
+        logging.info("Close text detector.")
         if hasattr(self, "predictor"):
             del self.predictor
         gc.collect()

@@ -15,7 +15,7 @@
 #
 import os
 import logging
-from api.utils import get_base_config, decrypt_database_config
+from api.utils.configs import get_base_config, decrypt_database_config
 from api.utils.file_utils import get_project_base_directory

 # Server

@@ -3,7 +3,7 @@ import logging
 import pymysql
 from urllib.parse import quote_plus

-from api.utils import get_base_config
+from api.utils.configs import get_base_config
 from rag.utils import singleton

