apply pep8 formalize (#155)

This commit is contained in:
KevinHuSh
2024-03-27 11:33:46 +08:00
committed by GitHub
parent a02e836790
commit fd7fcb5baf
55 changed files with 1568 additions and 753 deletions

View File

@ -34,10 +34,12 @@ from . import file_utils
SERVICE_CONF = "service_conf.yaml"
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
return config.get(key, default) if key is not None else config
use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
class CoordinationCommunicationProtocol(object):
@ -93,7 +96,8 @@ class BaseType:
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__, "data": data, "module": module}
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
@ -129,7 +133,8 @@ def rag_uuid():
def string_to_bytes(string):
return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
@ -137,7 +142,11 @@ def bytes_to_string(byte):
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False):
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)
def current_timestamp():
@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False):
def deserialize_b64(src):
src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
@ -237,12 +249,14 @@ def get_lan_ip():
pass
return ip or ''
def from_dict_hook(in_dict: dict):
if "type" in in_dict and "data" in in_dict:
if in_dict["module"] is None:
return in_dict["data"]
else:
return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
return getattr(importlib.import_module(
in_dict["module"]), in_dict["type"])(**in_dict["data"])
else:
return in_dict
@ -259,12 +273,16 @@ def decrypt_database_password(password):
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(database=None, passwd_key="password", name="database"):
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
@ -288,7 +307,8 @@ def get_uuid():
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
return datetime.datetime(date_time.year, date_time.month, date_time.day,
date_time.hour, date_time.minute, date_time.second)
def get_format_time() -> datetime.datetime:
@ -307,14 +327,19 @@ def elapsed2time(elapsed):
def decrypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
return cipher.decrypt(base64.b64decode(
line), "Fail to decrypt password!").decode('utf-8')
def download_img(url):
if not url: return ""
if not url:
return ""
response = requests.get(url)
return "data:" + \
response.headers.get('Content-Type', 'image/jpg') + ";" + \

View File

@ -19,7 +19,7 @@ import time
from functools import wraps
from io import BytesIO
from flask import (
Response, jsonify, send_file,make_response,
Response, jsonify, send_file, make_response,
request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES
@ -29,7 +29,7 @@ from api.versions import get_rag_version
from api.settings import RetCode
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
@ -40,14 +40,21 @@ from hmac import HMAC
from urllib.parse import quote, urlencode
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder)
def request(**kwargs):
sess = requests.Session()
stream = kwargs.pop('stream', sess.stream)
timeout = kwargs.pop('timeout', None)
kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
kwargs['headers'] = {
k.replace(
'_',
'-').upper(): v for k,
v in kwargs.get(
'headers',
{}).items()}
prepped = requests.Request(**kwargs).prepare()
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
@ -59,7 +66,11 @@ def request(**kwargs):
HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'',
urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
urlencode(
sorted(
kwargs['data'].items()),
quote_via=quote,
safe='-._~').encode('ascii')
if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
]), 'sha1').digest()).decode('ascii')
@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
return max(0, countdown)
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
data=None, job_id=None, meta=None):
import re
result_dict = {
"retcode": retcode,
"retmsg":retmsg,
"retmsg": retmsg,
# "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
"data": data,
"jobId": job_id,
@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id
response[key] = value
return jsonify(response)
def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
def get_data_error_result(retcode=RetCode.DATA_ERROR,
retmsg='Sorry! Data missing!'):
import re
result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
result_dict = {
"retcode": retcode,
"retmsg": re.sub(
r"rag",
"seceum",
retmsg,
flags=re.IGNORECASE)}
response = {}
for key, value in result_dict.items():
if value is None and key != "retcode":
@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin
response[key] = value
return jsonify(response)
def server_error_response(e):
stat_logger.exception(e)
try:
if e.code==401:
if e.code == 401:
return get_json_result(retcode=401, retmsg=repr(e))
except:
except BaseException:
pass
if len(e.args) > 1:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
return get_json_result(
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
@ -162,10 +184,13 @@ def validate_request(*args, **kwargs):
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
error_string += "required argument are missing: {}; ".format(
",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
error_string += "required argument values: {}".format(
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs)
return decorated_function
return wrapper
@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
return jsonify(response)
def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
def cors_reponse(retcode=RetCode.SUCCESS,
retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {}
for key, value in result_dict.items():
@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response
return response

View File

@ -29,6 +29,7 @@ from api.db import FileType
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
@ -65,7 +66,6 @@ def get_rag_python_directory(*args):
return get_rag_directory("python", *args)
@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
if os.path.isabs(conf_path):
@ -146,10 +146,12 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
if re.match(
r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
if re.match(
r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
@ -164,14 +166,16 @@ def thumbnail(filename, blob):
buffered = BytesIO()
Image.frombytes("RGB", [pix.width, pix.height],
pix.samples).save(buffered, format="png")
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
return "data:image/png;base64," + \
base64.b64encode(buffered.getvalue()).decode("utf-8")
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
image = Image.open(BytesIO(blob))
image.thumbnail((30, 30))
buffered = BytesIO()
image.save(buffered, format="png")
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
return "data:image/png;base64," + \
base64.b64encode(buffered.getvalue()).decode("utf-8")
if re.match(r".*\.(ppt|pptx)$", filename):
import aspose.slides as slides
@ -179,8 +183,10 @@ def thumbnail(filename, blob):
try:
with slides.Presentation(BytesIO(blob)) as presentation:
buffered = BytesIO()
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
presentation.slides[0].get_thumbnail(0.03, 0.03).save(
buffered, drawing.imaging.ImageFormat.png)
return "data:image/png;base64," + \
base64.b64encode(buffered.getvalue()).decode("utf-8")
except Exception as e:
pass
@ -190,6 +196,3 @@ def traversal_files(base):
for f in fs:
fullname = os.path.join(root, f)
yield fullname

View File

@ -23,6 +23,7 @@ from threading import RLock
from api.utils import file_utils
class LoggerFactory(object):
TYPE = "FILE"
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
@ -49,7 +50,8 @@ class LoggerFactory(object):
schedule_logger_dict = {}
@staticmethod
def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
def set_directory(directory=None, parent_log_dir=None,
append_to_parent_log=None, force=False):
if parent_log_dir:
LoggerFactory.PARENT_LOG_DIR = parent_log_dir
if append_to_parent_log:
@ -66,11 +68,13 @@ class LoggerFactory(object):
else:
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
for className, (logger, handler) in LoggerFactory.logger_dict.items():
for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(ghandler)
ghandler.close()
LoggerFactory.global_handler_dict = {}
for className, (logger, handler) in LoggerFactory.logger_dict.items():
for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(handler)
_handler = None
if handler:
@ -111,19 +115,23 @@ class LoggerFactory(object):
if logger_name_key not in LoggerFactory.global_handler_dict:
with LoggerFactory.lock:
if logger_name_key not in LoggerFactory.global_handler_dict:
handler = LoggerFactory.get_handler(logger_name, level, log_dir)
handler = LoggerFactory.get_handler(
logger_name, level, log_dir)
LoggerFactory.global_handler_dict[logger_name_key] = handler
return LoggerFactory.global_handler_dict[logger_name_key]
@staticmethod
def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
def get_handler(class_name, level=None, log_dir=None,
log_type=None, job_id=None):
if not log_type:
if not LoggerFactory.LOG_DIR or not class_name:
return logging.StreamHandler()
# return Diy_StreamHandler()
if not log_dir:
log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
log_file = os.path.join(
LoggerFactory.LOG_DIR,
"{}.log".format(class_name))
else:
log_file = os.path.join(log_dir, "{}.log".format(class_name))
else:
@ -133,16 +141,16 @@ class LoggerFactory(object):
os.makedirs(os.path.dirname(log_file), exist_ok=True)
if LoggerFactory.log_share:
handler = ROpenHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
when='D',
interval=1,
backupCount=14,
delay=True)
else:
handler = TimedRotatingFileHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
when='D',
interval=1,
backupCount=14,
delay=True)
if level:
handler.level = level
@ -170,7 +178,9 @@ class LoggerFactory(object):
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level]
logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
logger.addHandler(
LoggerFactory.get_global_handler(
level_logger_name, level))
if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
return f"{prefix}start to {msg}{suffix}"
def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
def successful_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} successfully{suffix}"
def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
def warning_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} is not effective{suffix}"
def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
def failed_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}failed to {msg}{suffix}"
def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
def base_msg(job=None, task=None, role: str = None,
party_id: typing.Union[str, int] = None, detail=None):
if detail:
detail_msg = f" detail: \n{detail}"
else:
@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type):
for job_log_dir in log_dirs:
handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
log_dir=job_log_dir, log_type=log_type, job_id=job_id)
error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id)
error_handler = LoggerFactory.get_handler(
class_name=None,
level=logging.ERROR,
log_dir=job_log_dir,
log_type=log_type,
job_id=job_id)
logger.addHandler(handler)
logger.addHandler(error_handler)
with LoggerFactory.lock:
LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
return logger

View File

@ -1,18 +1,23 @@
import base64, os, sys
import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import decrypt, file_utils
def crypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"public.pem")
rsa_key = RSA.importKey(open(file_path).read())
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")
return base64.b64encode(cipher.encrypt(
line.encode('utf-8'))).decode("utf-8")
if __name__ == "__main__":
pswd = crypt(sys.argv[1])
print(pswd)
print(decrypt(pswd))