Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
rename web_server to api (#29)
* add front end code
* change license
* rename web_server to api
* change name to web_server
api/utils/__init__.py (Normal file, 321 lines)
@@ -0,0 +1,321 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import datetime
import io
import json
import os
import pickle
import socket
import time
import uuid
import requests
from enum import Enum, IntEnum
import importlib

from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5

from filelock import FileLock

from . import file_utils

SERVICE_CONF = "service_conf.yaml"


def conf_realpath(conf_name):
    conf_path = f"conf/{conf_name}"
    return os.path.join(file_utils.get_project_base_directory(), conf_path)


def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
    local_config = {}
    local_path = conf_realpath(f'local.{conf_name}')
    if default is None:
        default = os.environ.get(key.upper())

    # A local.service_conf.yaml, if present, overrides the shipped config.
    if os.path.exists(local_path):
        local_config = file_utils.load_yaml_conf(local_path)
        if not isinstance(local_config, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')

        if key is not None and key in local_config:
            return local_config[key]

    config_path = conf_realpath(conf_name)
    config = file_utils.load_yaml_conf(config_path)

    if not isinstance(config, dict):
        raise ValueError(f'Invalid config file: "{config_path}".')

    config.update(local_config)
    return config.get(key, default) if key is not None else config


use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)


class CoordinationCommunicationProtocol(object):
    HTTP = "http"
    GRPC = "grpc"


class BaseType:
    def to_dict(self):
        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])

    def to_dict_with_type(self):
        def _dict(obj):
            module = None
            if issubclass(obj.__class__, BaseType):
                data = {}
                for attr, v in obj.__dict__.items():
                    k = attr.lstrip("_")
                    data[k] = _dict(v)
                module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                data = []
                for vv in obj:
                    data.append(_dict(vv))
            elif isinstance(obj, dict):
                data = {}
                for _k, vv in obj.items():
                    data[_k] = _dict(vv)
            else:
                data = obj
            return {"type": obj.__class__.__name__, "data": data, "module": module}
        return _dict(self)


class CustomJSONEncoder(json.JSONEncoder):
    def __init__(self, **kwargs):
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        elif isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        elif isinstance(obj, datetime.timedelta):
            return str(obj)
        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
            return obj.value
        elif isinstance(obj, set):
            return list(obj)
        elif issubclass(type(obj), BaseType):
            if not self._with_type:
                return obj.to_dict()
            else:
                return obj.to_dict_with_type()
        elif isinstance(obj, type):
            return obj.__name__
        else:
            return json.JSONEncoder.default(self, obj)


def rag_uuid():
    return uuid.uuid1().hex


def string_to_bytes(string):
    return string if isinstance(string, bytes) else string.encode(encoding="utf-8")


def bytes_to_string(byte):
    return byte.decode(encoding="utf-8")


def json_dumps(src, byte=False, indent=None, with_type=False):
    dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
    if byte:
        dest = string_to_bytes(dest)
    return dest


def json_loads(src, object_hook=None, object_pairs_hook=None):
    if isinstance(src, bytes):
        src = bytes_to_string(src)
    return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)


def current_timestamp():
    # Milliseconds since the epoch.
    return int(time.time() * 1000)


def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    if not timestamp:
        timestamp = time.time() * 1000  # keep units in ms, matching the division below
    timestamp = int(timestamp) / 1000
    time_array = time.localtime(timestamp)
    str_date = time.strftime(format_string, time_array)
    return str_date


def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    time_array = time.strptime(time_str, format_string)
    time_stamp = int(time.mktime(time_array) * 1000)
    return time_stamp


def serialize_b64(src, to_str=False):
    dest = base64.b64encode(pickle.dumps(src))
    if not to_str:
        return dest
    else:
        return bytes_to_string(dest)


def deserialize_b64(src):
    src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
    if use_deserialize_safe_module:
        return restricted_loads(src)
    return pickle.loads(src)


# Top-level modules that RestrictedUnpickler is allowed to resolve.
safe_module = {
    'numpy',
    'rag_flow'
}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module.split('.')[0] in safe_module:
            _module = importlib.import_module(module)
            return getattr(_module, name)
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                     (module, name))


def restricted_loads(src):
    """Helper function analogous to pickle.loads()."""
    return RestrictedUnpickler(io.BytesIO(src)).load()


def get_lan_ip():
    if os.name != "nt":
        import fcntl
        import struct

        def get_interface_ip(ifname):
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            # 0x8915 is SIOCGIFADDR; bytes 20:24 of the result hold the IPv4 address.
            return socket.inet_ntoa(
                fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and os.name != "nt":
        interfaces = [
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        ]
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
                break
            except IOError:
                pass
    return ip or ''


def from_dict_hook(in_dict: dict):
    if "type" in in_dict and "data" in in_dict:
        if in_dict["module"] is None:
            return in_dict["data"]
        else:
            return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
    else:
        return in_dict


def decrypt_database_password(password):
    encrypt_password = get_base_config("encrypt_password", False)
    encrypt_module = get_base_config("encrypt_module", False)
    private_key = get_base_config("private_key", None)

    if not password or not encrypt_password:
        return password

    if not private_key:
        raise ValueError("No private key")

    # encrypt_module is "module.path#function_name".
    module_fun = encrypt_module.split("#")
    pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])

    return pwdecrypt_fun(private_key, password)


def decrypt_database_config(database=None, passwd_key="passwd", name="database"):
    if not database:
        database = get_base_config(name, {})

    database[passwd_key] = decrypt_database_password(database[passwd_key])
    return database


def update_config(key, value, conf_name=SERVICE_CONF):
    conf_path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)

    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
        config[key] = value
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)


def get_uuid():
    return uuid.uuid1().hex


def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
    # Truncate sub-second precision.
    return datetime.datetime(date_time.year, date_time.month, date_time.day,
                             date_time.hour, date_time.minute, date_time.second)


def get_format_time() -> datetime.datetime:
    return datetime_format(datetime.datetime.now())


def str2date(date_time: str):
    return datetime.datetime.strptime(date_time, '%Y-%m-%d')


def elapsed2time(elapsed):
    # elapsed is in milliseconds.
    seconds = elapsed / 1000
    minutes, second = divmod(seconds, 60)
    hour, minutes = divmod(minutes, 60)
    return '%02d:%02d:%02d' % (hour, minutes, second)


def decrypt(line):
    file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
    rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')


def download_img(url):
    if not url:
        return ""
    response = requests.get(url)
    return "data:" + \
           response.headers.get('Content-Type', 'image/jpg') + ";" + \
           "base64," + base64.b64encode(response.content).decode("utf-8")
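
Example (not part of the commit): a minimal usage sketch of the helpers above. It assumes the package imports as api.utils after the rename and that conf/service_conf.yaml exists; the "database" key and variable names are illustrative only.

    from api.utils import (get_base_config, json_dumps, json_loads,
                           serialize_b64, deserialize_b64,
                           current_timestamp, timestamp_to_date)

    # Read one section of conf/service_conf.yaml (a local.service_conf.yaml,
    # if present, takes precedence).
    db_conf = get_base_config("database", default={})

    # CustomJSONEncoder makes datetimes, Enums, sets and BaseType subclasses serializable.
    payload = {"ts": timestamp_to_date(current_timestamp()), "tags": {"a", "b"}}
    blob = json_dumps(payload, byte=True)
    assert json_loads(blob)["ts"] == payload["ts"]

    # Base64-wrapped pickle round trip; with use_deserialize_safe_module set,
    # only modules whitelisted in safe_module may be unpickled.
    token = serialize_b64({"n": 3}, to_str=True)
    assert deserialize_b64(token) == {"n": 3}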

api/utils/api_utils.py (Normal file, 212 lines)
@@ -0,0 +1,212 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import random
import time
from functools import wraps
from io import BytesIO
from flask import (
    Response, jsonify, send_file, make_response,
    request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES

from web_server.utils import json_dumps
from web_server.versions import get_rag_version
from web_server.settings import RetCode
from web_server.settings import (
    REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
    stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
from web_server.utils import CustomJSONEncoder
from uuid import uuid1
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode


# Make requests serialize JSON bodies with CustomJSONEncoder.
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


def request(**kwargs):
    sess = requests.Session()
    stream = kwargs.pop('stream', sess.stream)
    timeout = kwargs.pop('timeout', None)
    kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
    prepped = requests.Request(**kwargs).prepare()

    if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
        timestamp = str(round(time.time() * 1000))  # fix: `time` is the module, not a callable
        nonce = str(uuid1())
        signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
            timestamp.encode('ascii'),
            nonce.encode('ascii'),
            HTTP_APP_KEY.encode('ascii'),
            prepped.path_url.encode('ascii'),
            prepped.body if kwargs.get('json') else b'',
            urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
            if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
        ]), 'sha1').digest()).decode('ascii')

        prepped.headers.update({
            'TIMESTAMP': timestamp,
            'NONCE': nonce,
            'APP-KEY': HTTP_APP_KEY,
            'SIGNATURE': signature,
        })

    return sess.send(prepped, stream=stream, timeout=timeout)


rag_version = get_rag_version() or ''


def get_exponential_backoff_interval(retries, full_jitter=False):
    """Calculate the exponential backoff wait time."""
    # Will be zero if factor equals 0
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
    # Full jitter according to
    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    if full_jitter:
        countdown = random.randrange(countdown + 1)
    # Adjust according to maximum wait time and account for negative values.
    return max(0, countdown)


def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
    import re
    result_dict = {
        "retcode": retcode,
        "retmsg": retmsg,
        # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
        "data": data,
        "jobId": job_id,
        "meta": meta,
    }

    response = {}
    for key, value in result_dict.items():
        if value is None and key != "retcode":
            continue
        else:
            response[key] = value
    return jsonify(response)


def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
    import re
    result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "retcode":
            continue
        else:
            response[key] = value
    return jsonify(response)


def server_error_response(e):
    stat_logger.exception(e)
    try:
        if e.code == 401:
            return get_json_result(retcode=401, retmsg=repr(e))
    except Exception:
        pass
    if len(e.args) > 1:
        return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
    return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))


def error_response(response_code, retmsg=None):
    if retmsg is None:
        retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')

    return Response(json.dumps({
        'retmsg': retmsg,
        'retcode': response_code,
    }), status=response_code, mimetype='application/json')


def validate_request(*args, **kwargs):
    def wrapper(func):
        @wraps(func)
        def decorated_function(*_args, **_kwargs):
            input_arguments = flask_request.json or flask_request.form.to_dict()
            no_arguments = []
            error_arguments = []
            for arg in args:
                if arg not in input_arguments:
                    no_arguments.append(arg)
            for k, v in kwargs.items():
                config_value = input_arguments.get(k, None)
                if config_value is None:
                    no_arguments.append(k)
                elif isinstance(v, (tuple, list)):
                    if config_value not in v:
                        error_arguments.append((k, set(v)))
                elif config_value != v:
                    error_arguments.append((k, v))
            if no_arguments or error_arguments:
                error_string = ""
                if no_arguments:
                    error_string += "required arguments are missing: {}; ".format(",".join(no_arguments))
                if error_arguments:
                    error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
                return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
            return func(*_args, **_kwargs)
        return decorated_function
    return wrapper


def is_localhost(ip):
    return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}


def send_file_in_mem(data, filename):
    if not isinstance(data, (str, bytes)):
        data = json_dumps(data)
    if isinstance(data, str):
        data = data.encode('utf-8')

    f = BytesIO()
    f.write(data)
    f.seek(0)

    # attachment_filename was renamed download_name in Flask 2.0.
    return send_file(f, as_attachment=True, attachment_filename=filename)


# NOTE: this redefinition shadows the five-argument get_json_result above.
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
    response = {"retcode": retcode, "retmsg": retmsg, "data": data}
    return jsonify(response)


def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
    result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
    response_dict = {}
    for key, value in result_dict.items():
        if value is None and key != "retcode":
            continue
        else:
            response_dict[key] = value
    response = make_response(jsonify(response_dict))
    if auth:
        response.headers["Authorization"] = auth
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Method"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    response.headers["Access-Control-Expose-Headers"] = "Authorization"
    return response
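
Example (not part of the commit): how validate_request and the response helpers compose in a Flask view. The route and field names are hypothetical; the import path follows the web_server package names used inside the file.

    from flask import Flask
    from web_server.utils.api_utils import (
        validate_request, get_json_result, server_error_response)

    app = Flask(__name__)
    app.register_error_handler(Exception, server_error_response)

    @app.route("/v1/demo", methods=["POST"])  # hypothetical endpoint
    @validate_request("name", status=("enabled", "disabled"))
    def demo():
        # Reaching here means "name" and "status" were both supplied and
        # "status" was one of the two allowed values; otherwise the decorator
        # already returned an ARGUMENT_ERROR payload.
        return get_json_result(data={"ok": True})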

api/utils/file_utils.py (Normal file, 153 lines)
@@ -0,0 +1,153 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os
import re

from cachetools import LRUCache, cached
from ruamel.yaml import YAML

from web_server.db import FileType

PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")


def get_project_base_directory(*args):
    global PROJECT_BASE
    if PROJECT_BASE is None:
        # Two levels up from this file (api/utils/) is the project root.
        PROJECT_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
                os.pardir,
            )
        )

    if args:
        return os.path.join(PROJECT_BASE, *args)
    return PROJECT_BASE


def get_rag_directory(*args):
    global RAG_BASE
    if RAG_BASE is None:
        RAG_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
                os.pardir,
                os.pardir,
            )
        )
    if args:
        return os.path.join(RAG_BASE, *args)
    return RAG_BASE


def get_rag_python_directory(*args):
    return get_rag_directory("python", *args)


@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError(
            "loading json file config from '{}' failed!".format(json_conf_path)
        )


def dump_json_conf(config_data, conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path, "w") as f:
            json.dump(config_data, f, indent=4)
    except BaseException:
        raise EnvironmentError(
            "dumping json file config to '{}' failed!".format(json_conf_path)
        )


def load_json_conf_real_time(conf_path):
    # Same as load_json_conf, but without the LRU cache.
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError(
            "loading json file config from '{}' failed!".format(json_conf_path)
        )


def load_yaml_conf(conf_path):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path) as f:
            yaml = YAML(typ='safe', pure=True)
            return yaml.load(f)
    except Exception as e:
        raise EnvironmentError(
            "loading yaml file config from {} failed:".format(conf_path), e
        )


def rewrite_yaml_conf(conf_path, config):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path, "w") as f:
            yaml = YAML(typ="safe")
            yaml.dump(config, f)
    except Exception as e:
        raise EnvironmentError(
            "rewrite yaml file config {} failed:".format(conf_path), e
        )


def rewrite_json_file(filepath, json_data):
    with open(filepath, "w") as f:
        json.dump(json_data, f, indent=4, separators=(",", ": "))


def filename_type(filename):
    filename = filename.lower()
    if re.match(r".*\.pdf$", filename):
        return FileType.PDF.value

    if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
        return FileType.DOC.value

    if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):
        return FileType.AURAL.value

    if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
        return FileType.VISUAL.value  # .value added for consistency with the branches above

    # Unknown extensions fall through and return None.
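
Example (not part of the commit): the YAML helpers resolve relative paths against the project root; the config key shown is hypothetical.

    from web_server.utils.file_utils import (
        load_yaml_conf, rewrite_yaml_conf, filename_type)
    from web_server.db import FileType

    conf = load_yaml_conf("conf/service_conf.yaml") or {}
    conf["http_port"] = 9380  # hypothetical key
    rewrite_yaml_conf("conf/service_conf.yaml", conf)

    # Extensions are lowercased before matching.
    assert filename_type("Report.PDF") == FileType.PDF.value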

api/utils/log_utils.py (Normal file, 294 lines)
@@ -0,0 +1,294 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import typing
import traceback
import logging
import inspect
from logging.handlers import TimedRotatingFileHandler
from threading import RLock

from web_server.utils import file_utils


class LoggerFactory(object):
    TYPE = "FILE"
    LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
    LEVEL = logging.DEBUG
    logger_dict = {}
    global_handler_dict = {}

    LOG_DIR = None
    PARENT_LOG_DIR = None
    log_share = True

    append_to_parent_log = None

    lock = RLock()
    # Standard logging levels: DEBUG=10, INFO=20, WARNING=30, ERROR=40 (CRITICAL=50 unused here).
    levels = (10, 20, 30, 40)
    schedule_logger_dict = {}

    @staticmethod
    def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
        if parent_log_dir:
            LoggerFactory.PARENT_LOG_DIR = parent_log_dir
        if append_to_parent_log:
            LoggerFactory.append_to_parent_log = append_to_parent_log
        with LoggerFactory.lock:
            if not directory:
                directory = file_utils.get_project_base_directory("logs")
            if not LoggerFactory.LOG_DIR or force:
                LoggerFactory.LOG_DIR = directory
            if LoggerFactory.log_share:
                oldmask = os.umask(000)
                os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
                os.umask(oldmask)
            else:
                os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
            # Rebuild all handlers so existing loggers point at the new directory.
            for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
                for className, (logger, handler) in LoggerFactory.logger_dict.items():
                    logger.removeHandler(ghandler)
                ghandler.close()
            LoggerFactory.global_handler_dict = {}
            for className, (logger, handler) in LoggerFactory.logger_dict.items():
                logger.removeHandler(handler)
                _handler = None
                if handler:
                    handler.close()
                if className != "default":
                    _handler = LoggerFactory.get_handler(className)
                    logger.addHandler(_handler)
                LoggerFactory.assemble_global_handler(logger)
                LoggerFactory.logger_dict[className] = logger, _handler

    @staticmethod
    def new_logger(name):
        logger = logging.getLogger(name)
        logger.propagate = False
        logger.setLevel(LoggerFactory.LEVEL)
        return logger

    @staticmethod
    def get_logger(class_name=None):
        with LoggerFactory.lock:
            if class_name in LoggerFactory.logger_dict.keys():
                logger, handler = LoggerFactory.logger_dict[class_name]
                if not logger:
                    logger, handler = LoggerFactory.init_logger(class_name)
            else:
                logger, handler = LoggerFactory.init_logger(class_name)
            return logger

    @staticmethod
    def get_global_handler(logger_name, level=None, log_dir=None):
        if not LoggerFactory.LOG_DIR:
            return logging.StreamHandler()
        if log_dir:
            logger_name_key = logger_name + "_" + log_dir
        else:
            logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR
        if logger_name_key not in LoggerFactory.global_handler_dict:
            with LoggerFactory.lock:
                if logger_name_key not in LoggerFactory.global_handler_dict:
                    handler = LoggerFactory.get_handler(logger_name, level, log_dir)
                    LoggerFactory.global_handler_dict[logger_name_key] = handler
        return LoggerFactory.global_handler_dict[logger_name_key]

    @staticmethod
    def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
        if not log_type:
            if not LoggerFactory.LOG_DIR or not class_name:
                return logging.StreamHandler()

            if not log_dir:
                log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
            else:
                log_file = os.path.join(log_dir, "{}.log".format(class_name))
        else:
            log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
                log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))

        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        if LoggerFactory.log_share:
            handler = ROpenHandler(log_file,
                                   when='D',
                                   interval=1,
                                   backupCount=14,
                                   delay=True)
        else:
            handler = TimedRotatingFileHandler(log_file,
                                               when='D',
                                               interval=1,
                                               backupCount=14,
                                               delay=True)
        if level:
            handler.level = level

        return handler

    @staticmethod
    def init_logger(class_name):
        with LoggerFactory.lock:
            logger = LoggerFactory.new_logger(class_name)
            handler = None
            if class_name:
                handler = LoggerFactory.get_handler(class_name)
                logger.addHandler(handler)
                LoggerFactory.logger_dict[class_name] = logger, handler
            else:
                LoggerFactory.logger_dict["default"] = logger, handler

            LoggerFactory.assemble_global_handler(logger)
            return logger, handler

    @staticmethod
    def assemble_global_handler(logger):
        if LoggerFactory.LOG_DIR:
            for level in LoggerFactory.levels:
                if level >= LoggerFactory.LEVEL:
                    level_logger_name = logging._levelToName[level]
                    logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
        if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
            for level in LoggerFactory.levels:
                if level >= LoggerFactory.LEVEL:
                    level_logger_name = logging._levelToName[level]
                    logger.addHandler(
                        LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR))


def setDirectory(directory=None):
    LoggerFactory.set_directory(directory)


def setLevel(level):
    LoggerFactory.LEVEL = level


def getLogger(className=None, useLevelFile=False):
    if className is None:
        # The caller's frame is inspected but, as written, the name always falls back to 'stat'.
        frame = inspect.stack()[1]
        module = inspect.getmodule(frame[0])
        className = 'stat'
    return LoggerFactory.get_logger(className)


def exception_to_trace_string(ex):
    return "".join(traceback.TracebackException.from_exception(ex).format())


class ROpenHandler(TimedRotatingFileHandler):
    # TimedRotatingFileHandler that opens log files under a permissive umask,
    # so logs stay writable when several processes share the directory.
    def _open(self):
        prevumask = os.umask(000)
        rtv = TimedRotatingFileHandler._open(self)
        os.umask(prevumask)
        return rtv


def sql_logger(job_id='', log_type='sql'):
    key = job_id + log_type
    if key in LoggerFactory.schedule_logger_dict.keys():
        return LoggerFactory.schedule_logger_dict[key]
    return get_job_logger(job_id=job_id, log_type=log_type)


def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    prefix, suffix = base_msg(job, task, role, party_id, detail)
    return f"{prefix}{msg} ready{suffix}"


def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    prefix, suffix = base_msg(job, task, role, party_id, detail)
    return f"{prefix}start to {msg}{suffix}"


def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    prefix, suffix = base_msg(job, task, role, party_id, detail)
    return f"{prefix}{msg} successfully{suffix}"


def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    prefix, suffix = base_msg(job, task, role, party_id, detail)
    return f"{prefix}{msg} is not effective{suffix}"


def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    prefix, suffix = base_msg(job, task, role, party_id, detail)
    return f"{prefix}failed to {msg}{suffix}"


def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
    if detail:
        detail_msg = f" detail: \n{detail}"
    else:
        detail_msg = ""
    if task is not None:
        return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}"
    elif job is not None:
        return "", f" on {job.f_role} {job.f_party_id}{detail_msg}"
    elif role and party_id:
        return "", f" on {role} {party_id}{detail_msg}"
    else:
        return "", f"{detail_msg}"


def get_logger_base_dir():
    # NOTE: file_utils (above) defines get_rag_directory; the original call to
    # get_rag_flow_directory does not exist there, so it is redirected here and below.
    job_log_dir = file_utils.get_rag_directory('logs')
    return job_log_dir


def get_job_logger(job_id, log_type):
    rag_flow_log_dir = file_utils.get_rag_directory('logs', 'rag_flow')
    job_log_dir = file_utils.get_rag_directory('logs', job_id)
    if not job_id:
        log_dirs = [rag_flow_log_dir]
    else:
        if log_type == 'audit':
            log_dirs = [job_log_dir, rag_flow_log_dir]
        else:
            log_dirs = [job_log_dir]
    if LoggerFactory.log_share:
        oldmask = os.umask(000)
        os.makedirs(job_log_dir, exist_ok=True)
        os.makedirs(rag_flow_log_dir, exist_ok=True)
        os.umask(oldmask)
    else:
        os.makedirs(job_log_dir, exist_ok=True)
        os.makedirs(rag_flow_log_dir, exist_ok=True)
    logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
    for job_log_dir in log_dirs:
        handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
                                            log_dir=job_log_dir, log_type=log_type, job_id=job_id)
        error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR,
                                                  log_dir=job_log_dir, log_type=log_type, job_id=job_id)
        logger.addHandler(handler)
        logger.addHandler(error_handler)
    with LoggerFactory.lock:
        LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
    return logger
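
Example (not part of the commit): typical initialization order for the logger factory; the directory is an arbitrary example.

    import logging
    from web_server.utils.log_utils import LoggerFactory, getLogger, setLevel, start_log

    # Point the factory at a directory first; per-level global handlers and
    # per-name files (e.g. stat.log) are created under it, rotating daily.
    LoggerFactory.set_directory("/tmp/ragflow_logs")
    setLevel(logging.INFO)

    stat_logger = getLogger("stat")
    stat_logger.info(start_log("load model", detail="bge-large"))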

api/utils/t_crypt.py (Normal file, 18 lines)
@@ -0,0 +1,18 @@
import base64
import os
import sys

from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from web_server.utils import decrypt, file_utils


def crypt(line):
    file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
    rsa_key = RSA.importKey(open(file_path).read())
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")


if __name__ == "__main__":
    # Encrypt the first CLI argument, then round-trip it through decrypt().
    pswd = crypt(sys.argv[1])
    print(pswd)
    print(decrypt(pswd))
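
Example (not part of the commit): a one-off sketch for generating the key pair that crypt()/decrypt() expect under conf/. The passphrase "Welcome" matches the literal that decrypt() passes to RSA.importKey; treat this as an illustration, not a hardening recommendation.

    from Cryptodome.PublicKey import RSA

    key = RSA.generate(2048)
    with open("conf/private.pem", "wb") as f:
        # Encrypted private key; RSA.importKey(..., "Welcome") can read it back.
        f.write(key.export_key(passphrase="Welcome", pkcs=8,
                               protection="scryptAndAES128-CBC"))
    with open("conf/public.pem", "wb") as f:
        f.write(key.publickey().export_key())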