mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
rename web_server to api (#29)
* add front end code * change license * rename web_server to API * change name to web_server
This commit is contained in:
0
api/__init__.py
Normal file
0
api/__init__.py
Normal file
147
api/apps/__init__.py
Normal file
147
api/apps/__init__.py
Normal file
@ -0,0 +1,147 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask, request
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
|
||||
from web_server.db import StatusEnum
|
||||
from web_server.db.services import UserService
|
||||
from web_server.utils import CustomJSONEncoder
|
||||
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from web_server.settings import RetCode, SECRET_KEY, stat_logger
|
||||
from web_server.hook import HookManager
|
||||
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
|
||||
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
|
||||
from web_server.utils.api_utils import get_json_result, server_error_response
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
__all__ = ['app']


# Route Flask's own logger through the handlers configured on the access
# logger so all HTTP-level log records end up in the same sinks.
logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
    logger.addHandler(h)

# Parse request bodies as JSON regardless of the Content-Type header, and
# yield None instead of raising on a malformed body.
Request.json = property(lambda self: self.get_json(force=True, silent=True))

app = Flask(__name__)
# max_age: browsers may cache the CORS pre-flight response for 30 days.
CORS(app, supports_credentials=True,max_age = 2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
# Convert any uncaught exception into the standard JSON error payload.
app.errorhandler(Exception)(server_error_response)


## convenience for dev and debug
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
# Cap uploads at 64 MB.
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024

Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
|
||||
|
||||
|
||||
|
||||
def search_pages_path(pages_dir):
    """Return the page-module files (``*_app.py``) directly under *pages_dir*,
    skipping hidden files."""
    candidates = pages_dir.glob('*_app.py')
    return [candidate for candidate in candidates
            if not candidate.name.startswith('.')]
|
||||
|
||||
|
||||
def register_page(page_path):
    """Dynamically import a ``*_app.py`` page module, attach the Flask app and
    a fresh Blueprint to it, and mount the blueprint at
    ``/<API_VERSION>/<page_name>``.

    Returns the url prefix the page was mounted at.
    """
    # BUG FIX: str.rstrip('_app') strips *characters* from the set {_, a, p},
    # not the literal suffix — e.g. 'data_app' would become 'dat'. Strip the
    # exact '_app' suffix instead.
    page_name = page_path.stem
    if page_name.endswith('_app'):
        page_name = page_name[:-len('_app')]

    # Module path rooted at the 'web_server' package so relative imports in
    # the page module resolve consistently.
    module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))

    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    # Injected globals every page module relies on at import time.
    page.app = app
    page.manager = Blueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)

    # A page may override its public name with a module-level `page_name`.
    page_name = getattr(page, 'page_name', page_name)
    url_prefix = f'/{API_VERSION}/{page_name}'

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    return url_prefix
|
||||
|
||||
|
||||
# Directories scanned for *_app.py page modules: this package's own directory
# plus the sibling web_server/apps directory.
pages_dir = [
    Path(__file__).parent,
    Path(__file__).parent.parent / 'web_server' / 'apps',
]

# URL prefixes of every registered page; used below to distinguish
# client-facing endpoints from site-level ones during authentication.
client_urls_prefix = [
    register_page(path)
    for dir in pages_dir
    for path in search_pages_path(dir)
]
|
||||
|
||||
|
||||
def client_authentication_before_request():
    """Run the configured client-authentication hook on the incoming request.

    Returns the standard JSON error payload when the hook rejects the
    request, or None to let the request proceed.
    """
    params = ClientAuthenticationParameters(
        request.full_path, request.headers,
        request.form, request.data, request.json,
    )
    outcome = HookManager.client_authentication(params)
    if outcome.code != RetCode.SUCCESS:
        return get_json_result(outcome.code, outcome.message)
|
||||
|
||||
|
||||
def site_authentication_before_request():
    """Apply site-level authentication to non-client endpoints.

    Client-facing endpoints (those under a registered page prefix) are
    authenticated separately and pass straight through.
    """
    if any(request.path.startswith(prefix) for prefix in client_urls_prefix):
        return

    outcome = HookManager.site_authentication(AuthenticationParameters(
        request.headers.get('site_signature'),
        request.json,
    ))
    if outcome.code != RetCode.SUCCESS:
        return get_json_result(outcome.code, outcome.message)
|
||||
|
||||
|
||||
@app.before_request
def authentication_before_request():
    """Dispatch to client- or site-level authentication per settings."""
    if CLIENT_AUTHENTICATION:
        return client_authentication_before_request()
    if SITE_AUTHENTICATION:
        return site_authentication_before_request()
|
||||
|
||||
@login_manager.request_loader
def load_user(web_request):
    """Resolve the current user from the Authorization header.

    The header carries a signed access token; returns the matching valid
    user, or None when the header is absent, the token is invalid, or no
    user matches.
    """
    authorization = web_request.headers.get("Authorization")
    if not authorization:
        return None

    serializer = Serializer(secret_key=SECRET_KEY)
    try:
        access_token = str(serializer.loads(authorization))
        users = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
        return users[0] if users else None
    except Exception as e:
        stat_logger.exception(e)
        return None
|
||||
150
api/apps/chunk_app.py
Normal file
150
api/apps/chunk_app.py
Normal file
@ -0,0 +1,150 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import hashlib
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from rag.nlp import search, huqie
|
||||
from rag.utils import ELASTICSEARCH, rmSpace
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.db.services.user_service import UserTenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.settings import RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from web_server.utils.file_utils import filename_type
|
||||
|
||||
# Shared retrieval helper over Elasticsearch; no embedding model is attached
# here (keyword search only). NOTE(review): name is a typo for 'retrieval'.
retrival = search.Dealer(ELASTICSEARCH, None)
|
||||
|
||||
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
def list():
    """Page through the chunks of one document, optionally keyword-filtered.

    POST body: doc_id (required), page (default 1), size (default 30),
    keywords (default "").
    """
    req = request.json
    doc_id = req["doc_id"]
    page = req.get("page", 1)
    size = req.get("size", 30)
    question = req.get("keywords", "")
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        res = retrival.search({
            "doc_ids": [doc_id], "page": page, "size": size, "question": question
        }, search.index_name(tenants[0].tenant_id))
        return get_json_result(data=res)
    except Exception as e:
        # BUG FIX: str.find() returns -1 when absent, so the original `> 0`
        # test would miss a match at position 0; use membership instead.
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg=f'Index not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Fetch one chunk by id, stripping vector/internal fields and turning
    tokenized text fields back into readable strings."""
    chunk_id = request.args["chunk_id"]
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        hit = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id))
        if not hit.get("found"):
            return server_error_response("Chunk not found")

        chunk = hit["_source"]
        chunk["chunk_id"] = hit["_id"]
        # Embedding vectors and shadow (*_sm_*) fields are internal only.
        hidden = [field for field in chunk.keys() if re.search(r"(_vec$|_sm_)", field)]
        for field in chunk.keys():
            if re.search(r"(_tks|_ltks)", field):
                chunk[field] = rmSpace(chunk[field])
        for field in hidden:
            del chunk[field]

        return get_json_result(data=chunk)
    except Exception as e:
        if "NotFoundError" in str(e):
            return get_json_result(data=False, retmsg=f'Chunk not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd")
def set():
    """Update an existing chunk: re-tokenize its content and keywords,
    recompute its embedding, and upsert it into Elasticsearch."""
    req = request.json
    d = {
        "id": req["chunk_id"],
        "content_ltks": huqie.qie(req["content_ltks"]),
        "important_kwd": req["important_kwd"],
        "important_tks": huqie.qie(" ".join(req["important_kwd"])),
    }
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])

    try:
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        v, c = embd_mdl.encode([req["docnm_kwd"], req["content_ltks"]])
        # Blend the title vector (10%) with the content vector (90%).
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks", "important_kwd")
def create():
    """Create a chunk for a document.

    The chunk id is md5(content + doc_id), so re-posting identical content is
    idempotent.

    BUG FIX: this view was also named ``set``, colliding with the '/set'
    handler's endpoint name on the same blueprint (Flask raises on duplicate
    endpoint names); renamed to ``create`` to match its route.
    """
    req = request.json
    md5 = hashlib.md5()
    md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
    chunck_id = md5.hexdigest()
    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))

    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["doc_id"] = doc.id

        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
        v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
        # One more chunk, c more tokens on the document's counters.
        DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
        # Blend the title vector (10%) with the content vector (90%).
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data={"chunk_id": chunck_id})
    except Exception as e:
        return server_error_response(e)
|
||||
279
api/apps/document_app.py
Normal file
279
api/apps/document_app.py
Normal file
@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import pathlib
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from rag.nlp import search
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid
|
||||
from web_server.db import FileType
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.settings import RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from web_server.utils.file_utils import filename_type
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST'])
@login_required
@validate_request("kb_id")
def upload():
    """Upload one file into a knowledgebase: store the blob in MinIO and
    insert the corresponding Document row."""
    kb_id = request.form.get("kb_id")
    if not kb_id:
        return get_json_result(
            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
    if 'file' not in request.files:
        return get_json_result(
            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
    file = request.files['file']
    if file.filename == '':
        return get_json_result(
            data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)

    try:
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")

        # De-duplicate the display name within the KB, then the storage key
        # within the bucket, independently.
        filename = duplicate_name(
            DocumentService.query,
            name=file.filename,
            kb_id=kb.id)
        location = filename
        while MINIO.obj_exist(kb_id, location):
            location += "_"
        blob = file.read()
        # BUG FIX: the blob must be stored under `location` — the de-duplicated
        # key recorded on the Document row below — not `filename`; otherwise a
        # key collision left the recorded location pointing at nothing.
        MINIO.put(kb_id, location, blob)
        doc = DocumentService.insert({
            "id": get_uuid(),
            "kb_id": kb.id,
            "parser_id": kb.parser_id,
            "created_by": current_user.id,
            "type": filename_type(filename),
            "name": filename,
            "location": location,
            "size": len(blob)
        })
        return get_json_result(data=doc.to_json())
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("name", "kb_id")
def create():
    """Create an empty (virtual) document record inside a knowledgebase."""
    req = request.json
    kb_id = req["kb_id"]
    if not kb_id:
        return get_json_result(
            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)

    try:
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")

        if DocumentService.query(name=req["name"], kb_id=kb_id):
            return get_data_error_result(
                retmsg="Duplicated document name in the same knowledgebase.")

        record = {
            "id": get_uuid(),
            "kb_id": kb.id,
            "parser_id": kb.parser_id,
            "created_by": current_user.id,
            "type": FileType.VIRTUAL,
            "name": req["name"],
            "location": "",
            "size": 0,
        }
        doc = DocumentService.insert(record)
        return get_json_result(data=doc.to_json())
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """List documents of one knowledgebase, paginated and optionally
    filtered by keywords."""
    args = request.args
    kb_id = args.get("kb_id")
    if not kb_id:
        return get_json_result(
            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)

    try:
        docs = DocumentService.get_by_kb_id(
            kb_id,
            args.get("page", 1),
            args.get("page_size", 15),
            args.get("orderby", "create_time"),
            args.get("desc", True),
            args.get("keywords", ""),
        )
        return get_json_result(data=docs)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/change_status', methods=['POST'])
@login_required
@validate_request("doc_id", "status")
def change_status():
    """Enable (1) or disable (0) a document: update its DB status and add or
    remove its kb_id on the matching chunks in Elasticsearch."""
    req = request.json
    if str(req["status"]) not in ["0", "1"]:
        # BUG FIX: the error payload must be returned — previously execution
        # fell through and the invalid status was processed anyway.
        return get_json_result(
            data=False,
            retmsg='"Status" must be either 0 or 1!',
            retcode=RetCode.ARGUMENT_ERROR)

    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")

        if not DocumentService.update_by_id(
                req["doc_id"], {"status": str(req["status"])}):
            return get_data_error_result(
                retmsg="Database error (Document update)!")

        if str(req["status"]) == "0":
            # Disabled: drop this kb_id from the chunks' kb_id lists so the
            # document stops matching searches.
            ELASTICSEARCH.updateScriptByQuery(
                Q("term", doc_id=req["doc_id"]),
                scripts="""
                    if(ctx._source.kb_id.contains('%s'))
                        ctx._source.kb_id.remove(
                            ctx._source.kb_id.indexOf('%s')
                    );
                """ % (doc.kb_id, doc.kb_id),
                idxnm=search.index_name(kb.tenant_id))
        else:
            # Enabled: add the kb_id back if it is missing.
            ELASTICSEARCH.updateScriptByQuery(
                Q("term", doc_id=req["doc_id"]),
                scripts="""
                    if(!ctx._source.kb_id.contains('%s'))
                        ctx._source.kb_id.add('%s');
                """ % (doc.kb_id, doc.kb_id),
                idxnm=search.index_name(kb.tenant_id))
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("doc_id")
def rm():
    """Delete a document: purge its chunks from Elasticsearch, roll back the
    KB's chunk/token counters, delete the DB row, then remove the blob."""
    req = request.json
    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        # NOTE(review): other handlers in this module build the index name
        # from a tenant id; passing doc.kb_id here looks inconsistent —
        # confirm against search.index_name's expected argument.
        if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
            return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)

        # Negative deltas roll the counters back by this doc's contribution.
        DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
        if not DocumentService.delete_by_id(req["doc_id"]):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")

        MINIO.rm(doc.kb_id, doc.location)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rename', methods=['POST'])
@login_required
@validate_request("doc_id", "name", "old_name")
def rename():
    """Rename a document; the file extension must stay the same and the new
    name must be unique within its knowledgebase."""
    req = request.json
    if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
            req["old_name"].lower()).suffix:
        # BUG FIX: the error must be returned — previously execution fell
        # through and the document was renamed despite the extension mismatch.
        return get_json_result(
            data=False,
            retmsg="The extension of file can't be changed",
            retcode=RetCode.ARGUMENT_ERROR)

    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        if DocumentService.query(name=req["name"], kb_id=doc.kb_id):
            return get_data_error_result(
                retmsg="Duplicated document name in the same knowledgebase.")

        if not DocumentService.update_by_id(
                req["doc_id"], {"name": req["name"]}):
            return get_data_error_result(
                retmsg="Database error (Document rename)!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Download a document's raw blob, base64-encoded in the JSON payload."""
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        blob = MINIO.get(doc.kb_id, doc.location)
        # BUG FIX: the stored blob is raw bytes, so it must be base64-
        # *encoded* for the JSON response; b64decode on arbitrary bytes
        # raises or garbles the payload.
        return get_json_result(data={"base64": base64.b64encode(blob).decode("utf-8")})
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Switch a document's parser and reset its parsing progress and
    accumulated statistics so it will be re-processed."""
    req = request.json
    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        if doc.parser_id.lower() == req["parser_id"].lower():
            # Already using the requested parser — nothing to do.
            return get_json_result(data=True)

        # Reset progress so the pipeline picks the document up again.
        e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        # Negative deltas roll back this doc's counters. NOTE(review):
        # `process_duation` is the model field's spelling — presumably a typo
        # for 'duration'; verify against the Document model.
        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
114
api/apps/kb_app.py
Normal file
114
api/apps/kb_app.py
Normal file
@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db import StatusEnum, UserTenantRole
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_models import Knowledgebase
|
||||
from web_server.settings import stat_logger, RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "parser_id")
def create():
    """Create a knowledgebase owned by the current user; the requested name
    is de-duplicated among the user's valid KBs."""
    req = request.json
    req["name"] = req["name"].strip()
    req["name"] = duplicate_name(
        KnowledgebaseService.query,
        name=req["name"],
        tenant_id=current_user.id,
        status=StatusEnum.VALID.value)
    try:
        req["id"] = get_uuid()
        req["tenant_id"] = current_user.id
        req["created_by"] = current_user.id
        if not KnowledgebaseService.save(**req):
            return get_data_error_result()
        return get_json_result(data={"kb_id": req["id"]})
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/update', methods=['post'])
@login_required
@validate_request("kb_id", "name", "description", "permission", "parser_id")
def update():
    """Update a knowledgebase; owner-only, and a renamed KB must not collide
    with another of the user's valid KBs."""
    req = request.json
    req["name"] = req["name"].strip()
    try:
        if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False,
                retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

        e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
        if not e:
            return get_data_error_result(retmsg="Can't find this knowledgebase!")

        # Only check for collisions when the name actually changes.
        renamed = req["name"].lower() != kb.name.lower()
        if renamed and len(KnowledgebaseService.query(
                name=req["name"], tenant_id=current_user.id,
                status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(retmsg="Duplicated knowledgebase name.")

        del req["kb_id"]
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_data_error_result()

        e, kb = KnowledgebaseService.get_by_id(kb.id)
        if not e:
            return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")

        return get_json_result(data=kb.to_json())
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of one knowledgebase."""
    kb_id = request.args["kb_id"]
    try:
        kb = KnowledgebaseService.get_detail(kb_id)
        if not kb:
            return get_data_error_result(retmsg="Can't find this knowledgebase!")
        return get_json_result(data=kb)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """List the knowledgebases visible to the current user across all joined
    tenants, paginated."""
    args = request.args
    page_number = args.get("page", 1)
    items_per_page = args.get("page_size", 15)
    orderby = args.get("orderby", "create_time")
    desc = args.get("desc", True)
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        tenant_ids = [m["tenant_id"] for m in tenants]
        kbs = KnowledgebaseService.get_by_tenant_ids(
            tenant_ids, current_user.id,
            page_number, items_per_page, orderby, desc)
        return get_json_result(data=kbs)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post'])
@login_required
@validate_request("kb_id")
def rm():
    """Soft-delete a knowledgebase (mark it invalid); owner only."""
    req = request.json
    try:
        if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False,
                retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

        ok = KnowledgebaseService.update_by_id(
            req["kb_id"], {"status": StatusEnum.IN_VALID.value})
        if not ok:
            return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||||
93
api/apps/llm_app.py
Normal file
93
api/apps/llm_app.py
Normal file
@ -0,0 +1,93 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db import StatusEnum, UserTenantRole
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_models import Knowledgebase, TenantLLM
|
||||
from web_server.settings import stat_logger, RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """List every registered LLM factory (provider)."""
    try:
        providers = LLMFactoriesService.get_all()
        return get_json_result(data=[provider.to_dict() for provider in providers])
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (or replace) the current user's API key for an LLM factory.

    Optional body fields ``model_type`` and ``llm_name`` narrow the key to a
    specific model.
    """
    req = request.json
    llm = {
        "tenant_id": current_user.id,
        "llm_factory": req["llm_factory"],
        "api_key": req["api_key"]
    }
    # TODO: Test api_key
    for n in ["model_type", "llm_name"]:
        if n in req:
            llm[n] = req[n]

    try:
        # Upsert: replace any existing key row for this tenant/factory.
        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        # Consistency fix: every sibling handler converts DB failures into the
        # standard error payload; this one previously let them propagate.
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """List the LLMs the current user has configured credentials for."""
    try:
        configured = TenantLLMService.get_my_llms(current_user.id)
        return get_json_result(data=configured)
    except Exception as e:
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """List all valid LLMs grouped by factory, flagging which ones the
    current user can actually use (i.e. has an API key for)."""
    try:
        objs = TenantLLMService.query(tenant_id=current_user.id)
        objs = [o.to_dict() for o in objs if o.api_key]
        # factory id -> explicitly keyed model names; an empty list means a
        # factory-wide key (every model of that factory is usable).
        fct = {}
        for o in objs:
            fct.setdefault(o["llm_factory"], [])
            if o["llm_name"]:
                fct[o["llm_factory"]].append(o["llm_name"])

        llms = [m.to_dict() for m in LLMService.get_all()
                if m.status == StatusEnum.VALID.value]
        for m in llms:
            keyed = fct.get(m["fid"])
            m["available"] = keyed is not None and (not keyed or m["llm_name"] in keyed)

        res = {}
        for m in llms:
            res.setdefault(m["fid"], []).append(m)

        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)
|
||||
268
api/apps/user_app.py
Normal file
268
api/apps/user_app.py
Normal file
@ -0,0 +1,268 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request, session, redirect, url_for
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_login import login_required, current_user, login_user, logout_user
|
||||
|
||||
from web_server.db.db_models import TenantLLM
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.utils.api_utils import server_error_response, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
||||
from web_server.db import UserTenantRole, LLMType
|
||||
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
||||
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
||||
from web_server.settings import stat_logger
|
||||
from web_server.utils.api_utils import get_json_result, cors_reponse
|
||||
|
||||
|
||||
@manager.route('/login', methods=['POST', 'GET'])
def login():
    """Log a user in, either via a GitHub OAuth session or email/password.

    Three entry paths:
      1. session carries a GitHub access token -> fetch the GitHub profile,
         auto-register the account if unseen, then log in.
      2. no token and no JSON body -> reject as unauthorized.
      3. JSON body with email/password -> verify credentials.
    """
    userinfo = None
    login_channel = "password"
    if session.get("access_token"):
        # OAuth path: the token was stashed by the provider callback view.
        login_channel = session["access_token_from"]
        if session["access_token_from"] == "github":
            userinfo = user_info_from_github(session["access_token"])
    elif not request.json:
        # No OAuth session and no credentials posted.
        # NOTE(review): 'Unautherized!' is a typo in a user-facing message.
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Unautherized!')

    email = request.json.get('email') if not userinfo else userinfo["email"]
    users = UserService.query(email=email)
    if not users:
        if request.json is not None:
            # Password path with an unknown email: plain rejection.
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
        # OAuth path with an unknown email: auto-register the account.
        avatar = ""
        try:
            avatar = download_img(userinfo["avatar_url"])
        except Exception as e:
            # Avatar download is best-effort; registration proceeds without it.
            stat_logger.exception(e)
        user_id = get_uuid()
        try:
            users = user_register(user_id, {
                "access_token": session["access_token"],
                "email": userinfo["email"],
                "avatar": avatar,
                "nickname": userinfo["login"],
                "login_channel": login_channel,
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users: raise Exception('Register user failure.')
            if len(users) > 1: raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
        except Exception as e:
            # Undo any partially created records for this user_id.
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return server_error_response(e)
    elif not request.json:
        # Known OAuth user revisiting without a body: log straight in.
        login_user(users[0])
        return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")

    # Password path for a known email.
    password = request.json.get('password')
    try:
        password = decrypt(password)
    except:
        # NOTE(review): bare except — presumably guards malformed ciphertext;
        # a narrower exception type would be safer.
        return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')

    user = UserService.query_user(email, password)
    if user:
        response_data = user.to_json()
        # Rotate the API access token on every successful login.
        user.access_token = get_uuid()
        login_user(user)
        user.save()
        msg = "Welcome back!"
        return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
    else:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
|
||||
|
||||
|
||||
@manager.route('/github_callback', methods=['GET'])
def github_callback():
    """Exchange the OAuth `code` for a GitHub token, stash it in the session, then re-enter /login."""
    try:
        import requests
        token_resp = requests.post(
            GITHUB_OAUTH.get("url"),
            data={
                "client_id": GITHUB_OAUTH.get("client_id"),
                "client_secret": GITHUB_OAUTH.get("secret_key"),
                "code": request.args.get('code'),
            },
            headers={"Accept": "application/json"},
        ).json()

        if "error" in token_resp:
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                                   retmsg=token_resp["error_description"])
        if "user:email" not in token_resp["scope"].split(","):
            # The email scope is required so /login can resolve the account.
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')

        session["access_token"] = token_resp["access_token"]
        session["access_token_from"] = "github"
        # 307 preserves the request method when re-entering the login view.
        return redirect(url_for("user.login"), code=307)

    except Exception as e:
        stat_logger.exception(e)
        return server_error_response(e)
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
    """Fetch the GitHub profile plus primary e-mail for an OAuth access token.

    Returns the /user JSON payload with an extra "email" key.
    Raises if the account has no primary e-mail (previously this crashed with
    an opaque TypeError when next() returned None).
    """
    import requests
    headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
    # Authenticate via the Authorization header only: passing the token as a
    # query parameter is deprecated by GitHub and leaks the secret into logs.
    user_info = requests.get("https://api.github.com/user", headers=headers).json()
    email_info = requests.get("https://api.github.com/user/emails", headers=headers).json()
    primary = next((email for email in email_info if email['primary']), None)
    if primary is None:
        raise Exception('No primary e-mail bound to this GitHub account.')
    user_info["email"] = primary["email"]
    return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=['GET'])
|
||||
@login_required
|
||||
def log_out():
|
||||
current_user.access_token = ""
|
||||
current_user.save()
|
||||
logout_user()
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/setting", methods=["POST"])
|
||||
@login_required
|
||||
def setting_user():
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
|
||||
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in ["password", "new_password"]:continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
UserService.update_by_id(current_user.id, update_dict)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"])
|
||||
@login_required
|
||||
def user_info():
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
|
||||
|
||||
def rollback_user_registration(user_id):
    """Best-effort cleanup of records created by a failed user_register() call.

    Each deletion is attempted independently and errors are ignored, so a
    partially created registration never masks the caller's original error.
    """
    try:
        TenantService.delete_by_id(user_id)
    except Exception:
        pass
    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception:
        pass
    try:
        # BUG FIX: was `.excute()` — the typo raised AttributeError (silently
        # swallowed below), so the tenant's LLM rows were never deleted.
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception:
        pass
|
||||
|
||||
|
||||
def user_register(user_id, user):
    """Create a user plus its personal tenant, owner membership and default LLM binding.

    Args:
        user_id: id shared by the user row and its personal tenant. The caller
            keeps it so rollback_user_registration(user_id) can undo a failure.
        user: dict of User column values; "id" is filled in here.

    Returns:
        The freshly queried user rows for user["email"], or None if the user
        row could not be saved.
    """
    # BUG FIX: this function used to overwrite user_id with a fresh get_uuid(),
    # so all records were created under a different id than the one the caller
    # later passed to rollback_user_registration(), making rollback a no-op.
    user["id"] = user_id
    tenant = {
        "id": user_id,
        "name": user["nickname"] + "‘s Kingdom",
        "llm_id": CHAT_MDL,
        "embd_id": EMBEDDING_MDL,
        "asr_id": ASR_MDL,
        "parser_ids": PARSERS,
        "img2txt_id": IMAGE2TEXT_MDL
    }
    usr_tenant = {
        "tenant_id": user_id,
        "user_id": user_id,
        "invited_by": user_id,
        "role": UserTenantRole.OWNER
    }
    # Placeholder key; the user replaces it with a real one later.
    tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}

    if not UserService.save(**user):
        return
    TenantService.save(**tenant)
    UserTenantService.save(**usr_tenant)
    TenantLLMService.save(**tenant_llm)
    return UserService.query(email=user["email"])
|
||||
|
||||
|
||||
@manager.route("/register", methods=["POST"])
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
req = request.json
|
||||
if UserService.query(email=req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_dict = {
|
||||
"access_token": get_uuid(),
|
||||
"email": req["email"],
|
||||
"nickname": req["nickname"],
|
||||
"password": decrypt(req["password"]),
|
||||
"login_channel": "password",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
}
|
||||
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"])
|
||||
@login_required
|
||||
def tenant_info():
|
||||
try:
|
||||
tenants = TenantService.get_by_user_id(current_user.id)[0]
|
||||
return get_json_result(data=tenants)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/set_tenant_info", methods=["POST"])
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
req = request.json
|
||||
try:
|
||||
tid = req["tenant_id"]
|
||||
del req["tenant_id"]
|
||||
TenantService.update_by_id(tid, req)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
54
api/db/__init__.py
Normal file
54
api/db/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
from strenum import StrEnum
|
||||
|
||||
|
||||
class StatusEnum(Enum):
    # Generic row status, stored as a single-char string column.
    VALID = "1"     # active / usable
    IN_VALID = "0"  # soft-deleted / disabled
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
    """Role a user holds within a tenant (see UserTenant.role)."""
    OWNER = 'owner'
    ADMIN = 'admin'
    NORMAL = 'normal'
|
||||
|
||||
|
||||
class TenantPermission(StrEnum):
    """Visibility scope of a resource within a tenant."""
    ME = 'me'      # visible to the creator only
    TEAM = 'team'  # visible to the whole tenant
|
||||
|
||||
|
||||
class SerializedType(IntEnum):
    """Wire format used by SerializedField in db_models."""
    PICKLE = 1
    JSON = 2
|
||||
|
||||
|
||||
class FileType(StrEnum):
    """Broad category of an uploaded document."""
    PDF = 'pdf'
    DOC = 'doc'
    VISUAL = 'visual'
    AURAL = 'aural'
    VIRTUAL = 'virtual'
|
||||
|
||||
|
||||
class LLMType(StrEnum):
    """Capability class of a model (matches LLM.model_type values)."""
    CHAT = 'chat'
    EMBEDDING = 'embedding'
    SPEECH2TEXT = 'speech2text'
    IMAGE2TEXT = 'image2text'
|
||||
619
api/db/db_models.py
Normal file
619
api/db/db_models.py
Normal file
@ -0,0 +1,619 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import operator
|
||||
from functools import wraps
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from flask_login import UserMixin
|
||||
|
||||
from peewee import (
|
||||
BigAutoField, BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
from web_server.db import SerializedType
|
||||
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
|
||||
from web_server.utils.log_utils import getLogger
|
||||
from web_server import utils
|
||||
|
||||
LOGGER = getLogger()
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
    """Decorator: cache a single instance of *cls* per process.

    Constructor arguments are fixed at decoration time; the wrapper takes
    none. The cache key includes the pid so forked workers each get their
    own instance.
    """
    cache = {}

    def _singleton():
        key = str(cls) + str(os.getpid())
        try:
            return cache[key]
        except KeyError:
            cache[key] = cls(*args, **kw)
            return cache[key]

    return _singleton
|
||||
|
||||
|
||||
# Peewee field classes treated as ordered ("continuous") so BaseModel.query()
# interprets a two-element list as a [low, high] range filter.
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
# Column-name prefixes whose "<prefix>_time"/"<prefix>_date" pairs are kept
# in sync automatically by BaseModel._normalize_data().
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
|
||||
|
||||
|
||||
class LongTextField(TextField):
    """TextField mapped to MySQL's LONGTEXT column type."""
    field_type = 'LONGTEXT'
|
||||
|
||||
|
||||
class JSONField(LongTextField):
    """LONGTEXT column that transparently (de)serializes JSON.

    None / empty DB values round-trip as `default_value` (an empty dict).
    """
    default_value = {}

    def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
        super().__init__(**kwargs)

    def db_value(self, value):
        # Persist the default rather than NULL so python_value stays symmetric.
        return utils.json_dumps(self.default_value if value is None else value)

    def python_value(self, value):
        if not value:
            return self.default_value
        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
    """JSONField whose empty/None values round-trip as an empty list."""
    default_value = []
|
||||
|
||||
|
||||
class SerializedField(LongTextField):
    """LONGTEXT column storing values either base64-pickled or as typed JSON.

    The format is picked once at field construction via `serialized_type`.
    """

    def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
        self._serialized_type = serialized_type
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
        super().__init__(**kwargs)

    def db_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
            return utils.serialize_b64(value, to_str=True)
        if self._serialized_type == SerializedType.JSON:
            # NULL-in, NULL-out for JSON mode.
            return None if value is None else utils.json_dumps(value, with_type=True)
        raise ValueError(f"the serialized type {self._serialized_type} is not supported")

    def python_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
            return utils.deserialize_b64(value)
        if self._serialized_type == SerializedType.JSON:
            # NULL deserializes to an empty dict in JSON mode.
            if value is None:
                return {}
            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
|
||||
def is_continuous_field(cls: typing.Type) -> bool:
    """True when *cls* or any ancestor is one of the continuous field types.

    Recursion stops at peewee's Field base and at object.
    """
    if cls in CONTINUOUS_FIELD_TYPE:
        return True
    return any(
        base in CONTINUOUS_FIELD_TYPE
        or (base is not Field and base is not object and is_continuous_field(base))
        for base in cls.__bases__
    )
|
||||
|
||||
|
||||
def auto_date_timestamp_field():
    """Model-attribute names of the auto-maintained timestamp columns (e.g. 'create_time')."""
    return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||
|
||||
|
||||
def auto_date_timestamp_db_field():
    """DB column names ('f_' prefixed) of the auto-maintained timestamp columns."""
    return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||
|
||||
|
||||
def remove_field_name_prefix(field_name):
    """Strip the legacy 'f_' column-name prefix, if present."""
    prefix = 'f_'
    if field_name.startswith(prefix):
        return field_name[len(prefix):]
    return field_name
|
||||
|
||||
|
||||
class BaseModel(Model):
    """Common peewee base: audit timestamps plus a generic keyword query API."""
    # Epoch-millisecond timestamps and their human-readable DateTime twins;
    # _normalize_data() keeps each *_time/*_date pair in sync.
    create_time = BigIntegerField(null=True)
    create_date = DateTimeField(null=True)
    update_time = BigIntegerField(null=True)
    update_date = DateTimeField(null=True)

    def to_json(self):
        # This function is obsolete
        return self.to_dict()

    def to_dict(self):
        """Return the raw column-name -> value mapping backing this row."""
        return self.__dict__['__data__']

    def to_human_model_dict(self, only_primary_with: list = None):
        """Dict of this row with 'f_' prefixes stripped; optionally restricted
        to the primary key plus the named extra columns."""
        model_dict = self.__dict__['__data__']

        if not only_primary_with:
            return {remove_field_name_prefix(k): v for k, v in model_dict.items()}

        human_model_dict = {}
        for k in self._meta.primary_key.field_names:
            human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
        for k in only_primary_with:
            human_model_dict[k] = model_dict[f'f_{k}']
        return human_model_dict

    @property
    def meta(self) -> Metadata:
        return self._meta

    @classmethod
    def get_primary_keys_name(cls):
        """Primary-key field names as a list (handles composite keys)."""
        return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
            cls._meta.primary_key.name]

    @classmethod
    def getter_by(cls, attr):
        # Resolve a (possibly dotted) attribute name to the model's field object.
        return operator.attrgetter(attr)(cls)

    @classmethod
    def query(cls, reverse=None, order_by=None, **kwargs):
        """Query rows by keyword filters.

        Each kwarg names a column. Scalar values filter by equality; a
        list/set on a continuous column of length 2 is treated as a
        [low, high] range (strings on timestamp columns are parsed first);
        a list/set on any other column becomes an IN filter.
        `reverse` True/False orders desc/asc by `order_by` (default
        'create_time'). With no usable filters, returns [].
        """
        filters = []
        for f_n, f_v in kwargs.items():
            attr_name = '%s' % f_n
            # Silently skip unknown columns and None values.
            if not hasattr(cls, attr_name) or f_v is None:
                continue
            if type(f_v) in {list, set}:
                f_v = list(f_v)
                if is_continuous_field(type(getattr(cls, attr_name))):
                    if len(f_v) == 2:
                        for i, v in enumerate(f_v):
                            if isinstance(v, str) and f_n in auto_date_timestamp_field():
                                # time type: %Y-%m-%d %H:%M:%S
                                f_v[i] = utils.date_string_to_timestamp(v)
                        # NOTE(review): naming looks inverted — lt_value is the
                        # lower bound and gt_value the upper; the comparisons
                        # below are consistent with that reading.
                        lt_value = f_v[0]
                        gt_value = f_v[1]
                        if lt_value is not None and gt_value is not None:
                            filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
                        elif lt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
                        elif gt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
                else:
                    # peewee's << operator builds an IN clause.
                    filters.append(operator.attrgetter(attr_name)(cls) << f_v)
            else:
                filters.append(operator.attrgetter(attr_name)(cls) == f_v)
        if filters:
            query_records = cls.select().where(*filters)
            if reverse is not None:
                if not order_by or not hasattr(cls, f"{order_by}"):
                    order_by = "create_time"
                if reverse is True:
                    query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
                elif reverse is False:
                    query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
            return [query_record for query_record in query_records]
        else:
            return []

    @classmethod
    def insert(cls, __data=None, **insert):
        # Stamp create_time on both the dict and keyword insert forms.
        if isinstance(__data, dict) and __data:
            __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
        if insert:
            insert["create_time"] = utils.current_timestamp()

        return super().insert(__data, **insert)

    # update and insert will call this method
    @classmethod
    def _normalize_data(cls, data, kwargs):
        """Stamp update_time and derive each *_date column from its *_time twin."""
        normalized = super()._normalize_data(data, kwargs)
        if not normalized:
            return {}

        normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()

        for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
            if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
                    cls._meta.combined[f"{f_n}_time"] in normalized and \
                    normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
                normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
                    normalized[cls._meta.combined[f"{f_n}_time"]])

        return normalized
|
||||
|
||||
|
||||
class JsonSerializedField(SerializedField):
    """SerializedField pre-configured for typed-JSON storage."""

    def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
        super().__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
                         object_pairs_hook=object_pairs_hook, **kwargs)
|
||||
|
||||
|
||||
@singleton
class BaseDataBase:
    """Process-wide holder of the pooled MySQL connection (one per pid via @singleton)."""
    def __init__(self):
        database_config = DATABASE.copy()
        # "name" is the schema name; the remaining keys are passed straight
        # through as PooledMySQLDatabase connection options.
        db_name = database_config.pop("name")
        self.database_connection = PooledMySQLDatabase(db_name, **database_config)
        stat_logger.info('init mysql database on cluster mode successfully')
|
||||
|
||||
|
||||
class DatabaseLock:
    """Named MySQL advisory lock built on GET_LOCK/RELEASE_LOCK.

    Usable as a context manager or as a decorator. The lock is only actually
    taken/released when the backing database is a PooledMySQLDatabase.
    """

    def __init__(self, lock_name, timeout=10, db=None):
        self.lock_name = lock_name
        self.timeout = int(timeout)
        self.db = db if db else DB

    def lock(self):
        # SQL parameters only support %s format placeholders
        cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
        state = cursor.fetchone()[0]
        if state == 1:
            return True
        if state == 0:
            # GET_LOCK returned 0: timed out waiting for another holder.
            raise Exception(f'acquire mysql lock {self.lock_name} timeout')
        # NULL: an error occurred while acquiring.
        raise Exception(f'failed to acquire lock {self.lock_name}')

    def unlock(self):
        cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
        state = cursor.fetchone()[0]
        if state == 1:
            return True
        if state == 0:
            # RELEASE_LOCK returned 0: held by a different connection.
            raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
        # NULL: no such lock.
        raise Exception(f'mysql lock {self.lock_name} does not exist')

    def __enter__(self):
        if isinstance(self.db, PooledMySQLDatabase):
            self.lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if isinstance(self.db, PooledMySQLDatabase):
            self.unlock()

    def __call__(self, func):
        @wraps(func)
        def magic(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return magic
|
||||
|
||||
|
||||
# Module-wide pooled connection; @singleton guarantees one per process.
DB = BaseDataBase().database_connection
# Convenience: `with DB.lock("name"):` yields a DatabaseLock on this connection.
DB.lock = DatabaseLock
|
||||
|
||||
|
||||
def close_connection():
    """Close the pooled DB connection; errors are logged, never raised."""
    try:
        if DB:
            DB.close()
    except Exception as e:
        LOGGER.exception(e)
|
||||
|
||||
|
||||
class DataBaseModel(BaseModel):
    """Base for all concrete tables: binds BaseModel to the pooled connection."""
    class Meta:
        database = DB
|
||||
|
||||
|
||||
@DB.connection_context()
def init_database_tables():
    """Create a table for every DataBaseModel subclass defined in this module.

    Failures are collected (not fatal per-table) and reported together; if any
    table failed, an exception naming them is raised at the end.
    """
    members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
    create_failed_list = []
    for _, model in members:
        if model == DataBaseModel or not issubclass(model, DataBaseModel):
            continue
        LOGGER.info(f"start create table {model.__name__}")
        try:
            model.create_table()
            LOGGER.info(f"create table success: {model.__name__}")
        except Exception as e:
            LOGGER.exception(e)
            create_failed_list.append(model.__name__)
    if create_failed_list:
        LOGGER.info(f"create tables failed: {create_failed_list}")
        raise Exception(f"create tables failed: {create_failed_list}")
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
    """Copy values from *human_model_dict* onto attributes declared by the
    model's class, ignoring keys the class does not define. Returns the
    (mutated) model object."""
    for key, value in human_model_dict.items():
        name = '%s' % key  # coerce non-string keys, matching column names
        if hasattr(model_object.__class__, name):
            setattr(model_object, name, value)
    return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
    """Account record; doubles as the Flask-Login user object."""
    id = CharField(max_length=32, primary_key=True)
    # Rotated on every login; serialized by get_id() for the auth header.
    access_token = CharField(max_length=255, null=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name")
    # NOTE(review): appears to hold a werkzeug password hash — verify against UserService.
    password = CharField(max_length=255, null=True, help_text="password")
    email = CharField(max_length=255, null=False, help_text="email", index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
    last_login_time = DateTimeField(null=True)
    # These three shadow UserMixin's properties with "0"/"1" string columns.
    is_authenticated = CharField(max_length=1, null=False, default="1")
    is_active = CharField(max_length=1, null=False, default="1")
    is_anonymous = CharField(max_length=1, null=False, default="0")
    login_channel = CharField(null=True, help_text="from which user login")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    is_superuser = BooleanField(null=True, help_text="is root", default=False)

    def __str__(self):
        return self.email

    def get_id(self):
        # Flask-Login identity: the access token signed with the app secret.
        jwt = Serializer(secret_key=SECRET_KEY)
        return jwt.dumps(str(self.access_token))

    class Meta:
        db_table = "user"
|
||||
|
||||
|
||||
class Tenant(DataBaseModel):
    """Workspace owning model defaults, knowledge bases and members."""
    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=100, null=True, help_text="Tenant name")
    public_key = CharField(max_length=255, null=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
    asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
    img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
    # FIX: help_text was a copy-paste of img2txt_id's; this column holds parser IDs.
    parser_ids = CharField(max_length=128, null=False, help_text="default document parser IDs")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "tenant"
|
||||
|
||||
|
||||
class UserTenant(DataBaseModel):
    """Membership link between a user and a tenant, with a role."""
    id = CharField(max_length=32, primary_key=True)
    user_id = CharField(max_length=32, null=False)
    tenant_id = CharField(max_length=32, null=False)
    # One of the UserTenantRole values (owner/admin/normal).
    role = CharField(max_length=32, null=False, help_text="UserTenantRole")
    invited_by = CharField(max_length=32, null=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "user_tenant"
|
||||
|
||||
|
||||
class InvitationCode(DataBaseModel):
    """Single-use invitation code, linked to a user/tenant once redeemed."""
    id = CharField(max_length=32, primary_key=True)
    code = CharField(max_length=32, null=False)
    visit_time = DateTimeField(null=True)
    user_id = CharField(max_length=32, null=True)
    tenant_id = CharField(max_length=32, null=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "invitation_code"
|
||||
|
||||
|
||||
class LLMFactories(DataBaseModel):
    """A model provider (e.g. OpenAI), keyed by its display name."""
    name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.name

    class Meta:
        db_table = "llm_factories"
|
||||
|
||||
|
||||
class LLM(DataBaseModel):
    # defautlt LLMs for every users
    """Catalog entry for a model offered by a factory."""
    llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
    model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    # References LLMFactories.name.
    fid = CharField(max_length=128, null=False, help_text="LLM factory id")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "llm"
|
||||
|
||||
|
||||
class TenantLLM(DataBaseModel):
    """Per-tenant credentials/config for a factory (optionally one model)."""
    tenant_id = CharField(max_length=32, null=False)
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
    # Empty llm_name means the config applies to the whole factory.
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
    api_key = CharField(max_length=255, null=True, help_text="API KEY")
    api_base = CharField(max_length=255, null=True, help_text="API Base")

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "tenant_llm"
        primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
||||
|
||||
|
||||
class Knowledgebase(DataBaseModel):
    """A document collection with its own default parser and counters."""
    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False)
    name = CharField(max_length=128, null=False, help_text="KB name", index=True)
    description = TextField(null=True, help_text="KB description")
    # TenantPermission value: 'me' or 'team'.
    permission = CharField(max_length=16, null=False, help_text="me|team")
    created_by = CharField(max_length=32, null=False)
    # Aggregate counters maintained as documents are processed.
    doc_num = IntegerField(default=0)
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)

    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.name

    class Meta:
        db_table = "knowledgebase"
|
||||
|
||||
|
||||
class Document(DataBaseModel):
    """An uploaded file within a knowledge base, with parsing progress."""
    id = CharField(max_length=32, primary_key=True)
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
    source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
    type = CharField(max_length=32, null=False, help_text="file extension")
    created_by = CharField(max_length=32, null=False, help_text="who created it")
    name = CharField(max_length=255, null=True, help_text="file name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store")
    size = IntegerField(default=0)
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)
    # Parsing progress in [0, 1].
    progress = FloatField(default=0)
    progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True)
    # NOTE(review): "duation" is a typo of "duration", but renaming would
    # change the DB column name — leave as-is.
    process_duation = FloatField(default=0)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "document"
|
||||
|
||||
|
||||
class Dialog(DataBaseModel):
    """Peewee model for a dialog (chat application) configuration.

    BUGFIX: the JSONField defaults were dict literals, so a single shared dict
    instance backed every new row — any in-place mutation of one row's default
    would leak into all later rows.  Peewee accepts a callable for ``default``
    and invokes it per row; the literals are now wrapped in lambdas with
    byte-identical contents.
    """
    id = CharField(max_length=32, primary_key=True)
    tenant_id = CharField(max_length=32, null=False)
    name = CharField(max_length=255, null=True, help_text="dialog application name")
    description = TextField(null=True, help_text="Dialog description")
    icon = CharField(max_length=16, null=False, help_text="dialog icon")
    language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
    llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
    llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
                                 default="Creative")
    llm_setting = JSONField(null=False, default=lambda: {
        "temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
        "presence_penalty": 0.4, "max_tokens": 215})
    prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
    prompt_config = JSONField(null=False, default=lambda: {
        "system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
        "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
    # "1" = valid/active, "0" = wasted (soft delete), matching StatusEnum.
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "dialog"
|
||||
|
||||
|
||||
class DialogKb(DataBaseModel):
    """Many-to-many link table between Dialog and Knowledgebase rows."""
    dialog_id = CharField(max_length=32, null=False, index=True)
    kb_id = CharField(max_length=32, null=False)

    class Meta:
        db_table = "dialog_kb"
        # One row per (dialog, kb) pair; the pair itself is the primary key.
        primary_key = CompositeKey('dialog_id', 'kb_id')
|
||||
|
||||
|
||||
class Conversation(DataBaseModel):
    """Peewee model for one conversation (message history) under a dialog."""
    id = CharField(max_length=32, primary_key=True)
    dialog_id = CharField(max_length=32, null=False, index=True)
    # NOTE(review): "converastion" typo kept verbatim — help_text is schema
    # metadata and changing it would alter generated DDL, not just a comment.
    name = CharField(max_length=255, null=True, help_text="converastion name")
    # Message history; internal structure is set by callers, not visible here.
    message = JSONField(null=True)

    class Meta:
        db_table = "conversation"
|
||||
|
||||
|
||||
"""
|
||||
class Job(DataBaseModel):
|
||||
# multi-party common configuration
|
||||
f_user_id = CharField(max_length=25, null=True)
|
||||
f_job_id = CharField(max_length=25, index=True)
|
||||
f_name = CharField(max_length=500, null=True, default='')
|
||||
f_description = TextField(null=True, default='')
|
||||
f_tag = CharField(max_length=50, null=True, default='')
|
||||
f_dsl = JSONField()
|
||||
f_runtime_conf = JSONField()
|
||||
f_runtime_conf_on_party = JSONField()
|
||||
f_train_runtime_conf = JSONField(null=True)
|
||||
f_roles = JSONField()
|
||||
f_initiator_role = CharField(max_length=50)
|
||||
f_initiator_party_id = CharField(max_length=50)
|
||||
f_status = CharField(max_length=50)
|
||||
f_status_code = IntegerField(null=True)
|
||||
f_user = JSONField()
|
||||
# this party configuration
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_is_initiator = BooleanField(null=True, default=False)
|
||||
f_progress = IntegerField(null=True, default=0)
|
||||
f_ready_signal = BooleanField(default=False)
|
||||
f_ready_time = BigIntegerField(null=True)
|
||||
f_cancel_signal = BooleanField(default=False)
|
||||
f_cancel_time = BigIntegerField(null=True)
|
||||
f_rerun_signal = BooleanField(default=False)
|
||||
f_end_scheduling_updates = IntegerField(null=True, default=0)
|
||||
|
||||
f_engine_name = CharField(max_length=50, null=True)
|
||||
f_engine_type = CharField(max_length=10, null=True)
|
||||
f_cores = IntegerField(default=0)
|
||||
f_memory = IntegerField(default=0) # MB
|
||||
f_remaining_cores = IntegerField(default=0)
|
||||
f_remaining_memory = IntegerField(default=0) # MB
|
||||
f_resource_in_use = BooleanField(default=False)
|
||||
f_apply_resource_time = BigIntegerField(null=True)
|
||||
f_return_resource_time = BigIntegerField(null=True)
|
||||
|
||||
f_inheritance_info = JSONField(null=True)
|
||||
f_inheritance_status = CharField(max_length=50, null=True)
|
||||
|
||||
f_start_time = BigIntegerField(null=True)
|
||||
f_start_date = DateTimeField(null=True)
|
||||
f_end_time = BigIntegerField(null=True)
|
||||
f_end_date = DateTimeField(null=True)
|
||||
f_elapsed = BigIntegerField(null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "t_job"
|
||||
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
|
||||
|
||||
|
||||
|
||||
class PipelineComponentMeta(DataBaseModel):
|
||||
f_model_id = CharField(max_length=100, index=True)
|
||||
f_model_version = CharField(max_length=100, index=True)
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_component_name = CharField(max_length=100, index=True)
|
||||
f_component_module_name = CharField(max_length=100)
|
||||
f_model_alias = CharField(max_length=100, index=True)
|
||||
f_model_proto_index = JSONField(null=True)
|
||||
f_run_parameters = JSONField(null=True)
|
||||
f_archive_sha256 = CharField(max_length=100, null=True)
|
||||
f_archive_from_ip = CharField(max_length=100, null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 't_pipeline_component_meta'
|
||||
indexes = (
|
||||
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
157
api/db/db_services.py
Normal file
157
api/db/db_services.py
Normal file
@ -0,0 +1,157 @@
|
||||
#
|
||||
# Copyright 2021 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import abc
|
||||
import json
|
||||
import time
|
||||
from functools import wraps
|
||||
from shortuuid import ShortUUID
|
||||
|
||||
from web_server.versions import get_rag_version
|
||||
|
||||
from web_server.errors.error_services import *
|
||||
from web_server.settings import (
|
||||
GRPC_PORT, HOST, HTTP_PORT,
|
||||
RANDOM_INSTANCE_ID, stat_logger,
|
||||
)
|
||||
|
||||
|
||||
# Unique identifier for this server process: a random 8-char id when
# RANDOM_INSTANCE_ID is enabled, otherwise a deterministic host/port name.
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
# (grpc url, metadata-json) pair registered via ServicesDB.register_flow().
server_instance = (
    f'{HOST}:{GRPC_PORT}',
    json.dumps({
        'instance_id': instance_id,
        'timestamp': round(time.time() * 1000),  # epoch millis at import time
        'version': get_rag_version() or '',
        'host': HOST,
        'grpc_port': GRPC_PORT,
        'http_port': HTTP_PORT,
    }),
)
|
||||
|
||||
|
||||
def check_service_supported(method):
    """Decorator to check if `service_name` is supported.

    The attribute `supported_services` MUST be defined in class.
    The first and second arguments of `method` MUST be `self` and `service_name`.

    :param Callable method: The class method.
    :return: The inner wrapper function.
    :rtype: Callable
    """
    @wraps(method)
    def guarded(self, service_name, *args, **kwargs):
        # Only dispatch to the wrapped method for names the backend declares.
        if service_name in self.supported_services:
            return method(self, service_name, *args, **kwargs)
        raise ServiceNotSupported(service_name=service_name)
    return guarded
|
||||
|
||||
|
||||
class ServicesDB(abc.ABC):
    """Database for storage service urls.
    Abstract base class for the real backends.

    Subclasses implement the underscore-prefixed hooks; the public wrappers
    add service-name validation (via ``check_service_supported``) and turn
    ``ServicesError`` into a logged no-op or empty result.
    """
    @property
    @abc.abstractmethod
    def supported_services(self):
        """The names of supported services.
        The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).

        :return: The service names.
        :rtype: list
        """
        pass

    @abc.abstractmethod
    def _get_serving(self):
        pass

    def get_serving(self):
        # Best-effort: backend failures are logged and an empty list returned.
        try:
            return self._get_serving()
        except ServicesError as e:
            stat_logger.exception(e)
            return []

    @abc.abstractmethod
    def _insert(self, service_name, service_url, value=''):
        pass

    @check_service_supported
    def insert(self, service_name, service_url, value=''):
        """Insert a service url to database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._insert(service_name, service_url, value)
        except ServicesError as e:
            stat_logger.exception(e)

    @abc.abstractmethod
    def _delete(self, service_name, service_url):
        pass

    @check_service_supported
    def delete(self, service_name, service_url):
        """Delete a service url from database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._delete(service_name, service_url)
        except ServicesError as e:
            stat_logger.exception(e)

    def register_flow(self):
        """Call `self.insert` to insert the flow server address into the database.

        :return: None
        """
        self.insert('flow-server', *server_instance)

    def unregister_flow(self):
        """Call `self.delete` to delete the flow server address from the database.

        :return: None
        """
        self.delete('flow-server', server_instance[0])

    @abc.abstractmethod
    def _get_urls(self, service_name, with_values=False):
        pass

    @check_service_supported
    def get_urls(self, service_name, with_values=False):
        """Query service urls from database. The urls may belong to other nodes.
        Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
        `ragflow` is a url containing scheme, host, port and path,
        while `servings` only contains host and port.

        :param str service_name: The service name.
        :return: The service urls.
        :rtype: list
        """
        try:
            return self._get_urls(service_name, with_values)
        except ServicesError as e:
            stat_logger.exception(e)
            return []
|
||||
131
api/db/db_utils.py
Normal file
131
api/db/db_utils.py
Normal file
@ -0,0 +1,131 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from web_server.utils import current_timestamp, timestamp_to_date
|
||||
|
||||
from web_server.db.db_models import DB, DataBaseModel
|
||||
from web_server.db.runtime_config import RuntimeConfig
|
||||
from web_server.utils.log_utils import getLogger
|
||||
from enum import Enum
|
||||
|
||||
|
||||
LOGGER = getLogger()
|
||||
|
||||
|
||||
@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
    """Insert `data_source` rows into `model`'s table in batches.

    Create/update timestamp columns (`f_create_time`, `f_create_date`,
    `f_update_time`, `f_update_date`) are stamped on every row first.

    :param model: Peewee model class whose table receives the rows.
    :param list data_source: List of dicts, one per row.
    :param bool replace_on_conflict: If True, upsert on key conflicts while
        preserving the create-time columns.
    :return: None
    """
    DB.create_tables([model])

    # BUGFIX: an empty list previously crashed at `data_source[0]` below.
    if not data_source:
        return

    current_time = current_timestamp()
    current_date = timestamp_to_date(current_time)

    for data in data_source:
        if 'f_create_time' not in data:
            data['f_create_time'] = current_time
        data['f_create_date'] = timestamp_to_date(data['f_create_time'])
        data['f_update_time'] = current_time
        data['f_update_date'] = current_date

    # On conflict, overwrite every column except the creation timestamps.
    preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})

    # NOTE(review): RuntimeConfig as defined in this package does not declare
    # USE_LOCAL_DATABASE — confirm it is attached elsewhere at startup,
    # otherwise this line raises AttributeError.
    batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000

    for i in range(0, len(data_source), batch_size):
        with DB.atomic():
            query = model.insert_many(data_source[i:i + batch_size])
            if replace_on_conflict:
                query = query.on_conflict(preserve=preserve)
            query.execute()
|
||||
|
||||
|
||||
def get_dynamic_db_model(base, job_id):
    """Return the model class for the tracking-table shard selected by `job_id`.

    NOTE(review): `type(x)` returns the class of the instance built by
    `base.model(...)`; presumably `base.model` binds a model to the per-job
    table — confirm against the definition of `base`, not visible here.
    """
    return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
|
||||
|
||||
|
||||
def get_dynamic_tracking_table_index(job_id):
    """Derive the tracking-table shard index: the first 8 chars of the job id."""
    prefix_length = 8
    return job_id[:prefix_length]
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
    """Copy dict entries onto `model_object` as `f_<key>` attributes.

    Keys whose `f_`-prefixed name is not declared on the model's class are
    silently skipped.  Returns the same (mutated) object.
    """
    model_cls = model_object.__class__
    for key, value in human_model_dict.items():
        field_name = f'f_{key}'
        if hasattr(model_cls, field_name):
            setattr(model_object, field_name, value)
    return model_object
|
||||
|
||||
|
||||
# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
# Maps query-dict operator strings to the Python operator functions that
# peewee overloads on Field objects to build SQL expressions (see the URL
# above for what each overload means on a peewee field).
supported_operators = {
    '==': operator.eq,
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '!=': operator.ne,
    '<<': operator.lshift,
    '>>': operator.rshift,
    '%': operator.mod,
    '**': operator.pow,
    '^': operator.xor,
    '~': operator.inv,
}
|
||||
|
||||
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
    """Translate a {field: value} / {field: (op, *args)} dict into one peewee
    expression, AND-ing all clauses together.

    A bare value is shorthand for ('==', value).  Operators in
    `supported_operators` are applied directly; any other op string is looked
    up as a method on the field (e.g. 'in_', 'contains') and called.
    """
    clauses = []
    for name, spec in query.items():
        if not isinstance(spec, (list, tuple)):
            spec = ('==', spec)
        op, *operands = spec

        column = getattr(model, f'f_{name}')
        if op in supported_operators:
            clause = supported_operators[op](column, operands[0])
        else:
            clause = getattr(column, op)(*operands)
        clauses.append(clause)

    return reduce(operator.iand, clauses)
|
||||
|
||||
|
||||
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
             query: dict = None, order_by: Union[str, list, tuple] = None):
    """Query `model` with optional filtering, ordering and pagination.

    :param model: Peewee model class (columns are `f_`-prefixed).
    :param limit: Max rows to return; 0 means unlimited.
    :param offset: Rows to skip; 0 means none.
    :param query: Filter dict translated by `query_dict2expression`.
    :param order_by: Field name, or (field_name, 'asc'|'desc'); defaults to
        ascending `create_time`.
    :return: (rows, count) — count is taken BEFORE limit/offset, i.e. the
        total number of rows matching the filter.
    """
    data = model.select()
    if query:
        data = data.where(query_dict2expression(model, query))
    count = data.count()

    if not order_by:
        order_by = 'create_time'
    if not isinstance(order_by, (list, tuple)):
        order_by = (order_by, 'asc')
    order_by, order = order_by
    # Resolve the f_-prefixed column, then call .asc() / .desc() on it.
    order_by = getattr(model, f'f_{order_by}')
    order_by = getattr(order_by, order)()
    data = data.order_by(order_by)

    if limit > 0:
        data = data.limit(limit)
    if offset > 0:
        data = data.offset(offset)

    return list(data), count
|
||||
|
||||
|
||||
class StatusEnum(Enum):
    # Record availability status (translated from the original Chinese comment):
    # VALID = active record, IN_VALID = soft-deleted/wasted record.
    VALID = "1"
    IN_VALID = "0"
||||
141
api/db/init_data.py
Normal file
141
api/db/init_data.py
Normal file
@ -0,0 +1,141 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.db_models import init_database_tables as init_web_db
|
||||
from web_server.db.services import UserService
|
||||
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
|
||||
|
||||
|
||||
def init_superuser():
    """Create the default admin account (called when the user table is empty).

    NOTE(review): the password is a hard-coded plaintext "admin"; whether it is
    hashed depends on UserService.save / the User model, which are not visible
    here.  It must be changed immediately in any real deployment.
    """
    user_info = {
        "id": uuid.uuid1().hex,
        "password": "admin",
        "nickname": "admin",
        "is_superuser": True,
        "email": "kai.hu@infiniflow.org",
        "creator": "system",
        "status": "1",
    }
    UserService.save(**user_info)
|
||||
|
||||
|
||||
def init_llm_factory():
    """Seed the LLM factory (vendor) and model catalogue tables.

    Rows are inserted unconditionally; `init_web_data` only calls this when
    the LLM table is empty.  Each llm entry's "fid" references a factory by
    its "name".
    """
    # Supported vendors: OpenAI, Tongyi Qianwen, Zhipu AI, Ernie Bot.
    factory_infos = [{
        "name": "OpenAI",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    },{
        "name": "通义千问",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    },{
        "name": "智普AI",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    },{
        "name": "文心一言",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    },
    ]
    # Model catalogue; only OpenAI ([0]) and Tongyi ([1]) models are seeded.
    llm_infos = [{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-3.5-turbo",
        "tags": "LLM,CHAT,4K",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-3.5-turbo-16k-0613",
        "tags": "LLM,CHAT,16k",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "text-embedding-ada-002",
        "tags": "TEXT EMBEDDING,8K",
        "model_type": LLMType.EMBEDDING.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "whisper-1",
        "tags": "SPEECH2TEXT",
        "model_type": LLMType.SPEECH2TEXT.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4",
        "tags": "LLM,CHAT,8K",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4-32k",
        "tags": "LLM,CHAT,32K",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4-vision-preview",
        "tags": "LLM,CHAT,IMAGE2TEXT",
        "model_type": LLMType.IMAGE2TEXT.value
    },{
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen-turbo",
        "tags": "LLM,CHAT,8K",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen-plus",
        "tags": "LLM,CHAT,32K",
        "model_type": LLMType.CHAT.value
    },{
        "fid": factory_infos[1]["name"],
        "llm_name": "text-embedding-v2",
        "tags": "TEXT EMBEDDING,2K",
        "model_type": LLMType.EMBEDDING.value
    },{
        "fid": factory_infos[1]["name"],
        "llm_name": "paraformer-realtime-8k-v1",
        "tags": "SPEECH2TEXT",
        "model_type": LLMType.SPEECH2TEXT.value
    },{
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen_vl_chat_v1",
        "tags": "LLM,CHAT,IMAGE2TEXT",
        "model_type": LLMType.IMAGE2TEXT.value
    },
    ]
    for info in factory_infos:
        LLMFactoriesService.save(**info)
    for info in llm_infos:
        LLMService.save(**info)
|
||||
|
||||
|
||||
def init_web_data():
    """Seed the database on first startup: superuser, then the LLM catalogue.

    Each seeding step runs only when its table is still empty, so repeated
    startups are no-ops.
    """
    start_time = time.time()

    if not UserService.get_all().count():
        init_superuser()

    if not LLMService.get_all().count():
        init_llm_factory()

    print("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Create all web tables, then seed default data (superuser + LLM catalogue).
    init_web_db()
    init_web_data()
|
||||
21
api/db/operatioins.py
Normal file
21
api/db/operatioins.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import operator
|
||||
import time
|
||||
import typing
|
||||
from web_server.utils.log_utils import sql_logger
|
||||
import peewee
|
||||
27
api/db/reload_config_base.py
Normal file
27
api/db/reload_config_base.py
Normal file
@ -0,0 +1,27 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
class ReloadConfigBase:
    """Expose a class's public, non-callable class attributes as a config dict."""

    @classmethod
    def get_all(cls):
        """Return {name: value} for every public, non-callable class attribute."""
        configs = {}
        for name, value in cls.__dict__.items():
            if name.startswith("__") or name.startswith("_"):
                continue
            if callable(getattr(cls, name)):
                continue
            configs[name] = value
        return configs

    @classmethod
    def get(cls, config_name):
        """Return the named config value, or None when it is not defined."""
        if hasattr(cls, config_name):
            return getattr(cls, config_name)
        return None
|
||||
54
api/db/runtime_config.py
Normal file
54
api/db/runtime_config.py
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from web_server.versions import get_versions
|
||||
from .reload_config_base import ReloadConfigBase
|
||||
|
||||
|
||||
class RuntimeConfig(ReloadConfigBase):
    """Process-wide runtime settings, populated at startup via `init_config`."""
    DEBUG = None
    WORK_MODE = None
    HTTP_PORT = None
    JOB_SERVER_HOST = None
    JOB_SERVER_VIP = None
    # Environment/version info; filled by `init_env` from `get_versions()`.
    ENV = dict()
    SERVICE_DB = None
    LOAD_CONFIG_MANAGER = False

    @classmethod
    def init_config(cls, **kwargs):
        # Only attributes already declared on the class are accepted;
        # unknown keys are silently ignored.
        for k, v in kwargs.items():
            if hasattr(cls, k):
                setattr(cls, k, v)

    @classmethod
    def init_env(cls):
        cls.ENV.update(get_versions())

    @classmethod
    def load_config_manager(cls):
        cls.LOAD_CONFIG_MANAGER = True

    @classmethod
    def get_env(cls, key):
        # Missing keys yield None rather than raising.
        return cls.ENV.get(key, None)

    @classmethod
    def get_all_env(cls):
        return cls.ENV

    @classmethod
    def set_service_db(cls, service_db):
        cls.SERVICE_DB = service_db
|
||||
38
api/db/services/__init__.py
Normal file
38
api/db/services/__init__.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pathlib
|
||||
import re
|
||||
from .user_service import UserService
|
||||
|
||||
|
||||
def duplicate_name(query_func, **kwargs):
    """Return a variant of kwargs["name"] that `query_func` reports as unused.

    While `query_func(**kwargs)` is truthy (name taken), append or bump a
    "(N)" counter before the file extension — "a.jpg" -> "a(1).jpg" ->
    "a(2).jpg" — and retry recursively.

    :param query_func: Callable queried with the candidate name in kwargs.
    :return: The first candidate name for which `query_func` returns falsy.

    BUGFIXES vs. the original:
    * the extension was interpolated into the regex unescaped, so extensions
      containing regex metacharacters (e.g. ".c++") raised re.error;
    * the counter pattern had no capture group, so `r.group(1)` raised on the
      second collision ("a(1).jpg" case).
    """
    fnm = kwargs["name"]
    objs = query_func(**kwargs)
    if not objs:
        return fnm
    ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"; "" when no extension
    nm = re.sub(r"%s$" % re.escape(ext), "", fnm)
    r = re.search(r"\(([0-9]+)\)$", nm)
    c = 0
    if r:
        c = int(r.group(1))
        nm = re.sub(r"\([0-9]+\)$", "", nm)
    c += 1
    nm = f"{nm}({c})"
    if ext:
        nm += f"{ext}"

    kwargs["name"] = nm
    return duplicate_name(query_func, **kwargs)
|
||||
|
||||
153
api/db/services/common_service.py
Normal file
153
api/db/services/common_service.py
Normal file
@ -0,0 +1,153 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
|
||||
from web_server.db.db_models import DB
|
||||
from web_server.utils import datetime_format
|
||||
|
||||
|
||||
class CommonService:
    """Generic CRUD service layer over a single peewee model.

    Subclasses set `model` to their model class; every DB-touching method
    opens a connection via the `DB.connection_context()` decorator.
    """
    # The peewee model this service operates on; set by subclasses.
    model = None

    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        """Delegate to the model's own `query` helper with equality filters."""
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
    def get_all(cls, cols=None, reverse=None, order_by=None):
        """Select all rows (optionally only `cols`), optionally ordered.

        reverse=True -> descending, False -> ascending, None -> unordered.
        NOTE(review): the fallback tests `hasattr(cls, order_by)` on the
        service class, not the model — confirm this is intentional.
        """
        if cols:
            query_records = cls.model.select(*cols)
        else:
            query_records = cls.model.select()
        if reverse is not None:
            if not order_by or not hasattr(cls, order_by):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    @DB.connection_context()
    def get(cls, **kwargs):
        """Return one row matching kwargs; peewee raises when none exists."""
        return cls.model.get(**kwargs)

    @classmethod
    @DB.connection_context()
    def get_or_none(cls, **kwargs):
        """Like `get`, but return None instead of raising DoesNotExist."""
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a new row built from kwargs; returns peewee's save() result."""
        #if "id" not in kwargs:
        #    kwargs["id"] = get_uuid()
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    @DB.connection_context()
    def insert_many(cls, data_list, batch_size=100):
        """Bulk-insert dicts in batches inside a single transaction."""
        with DB.atomic():
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i:i + batch_size]).execute()

    @classmethod
    @DB.connection_context()
    def update_many_by_id(cls, data_list):
        """Update each dict (matched on its "id") in one transaction,
        stamping the same "update_time" on every row."""
        cur = datetime_format(datetime.now())
        with DB.atomic():
            for data in data_list:
                data["update_time"] = cur
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def update_by_id(cls, pid, data):
        """Update the row with primary key `pid`; returns affected row count."""
        data["update_time"] = datetime_format(datetime.now())
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_by_id(cls, pid):
        """Return (True, row) for primary key `pid`, or (False, None).

        Any lookup failure (no row, DB error) maps to (False, None).
        """
        try:
            obj = cls.model.query(id=pid)[0]
            return True, obj
        except Exception as e:
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_ids(cls, pids, cols=None):
        """Select rows whose id is in `pids` (optionally only `cols`)."""
        if cols:
            objs = cls.model.select(*cols)
        else:
            objs = cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    @DB.connection_context()
    def delete_by_id(cls, pid):
        """Delete the row with primary key `pid`; returns affected row count."""
        return cls.model.delete().where(cls.model.id == pid).execute()


    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
        """Delete all rows matching the peewee expressions in `filters`."""
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    @DB.connection_context()
    def filter_update(cls, filters, update_data):
        """Update all rows matching `filters` with `update_data`."""
        with DB.atomic():
            cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        """Split `tar_list` into consecutive tuples of at most `n` items."""
        length = len(tar_list)
        arr = range(length)
        result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
        return result

    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        """Select rows where `in_key` is in `in_filters_list`, chunking the
        IN clause into groups of 20 to keep each query bounded; optional
        extra `filters` and column projection `cols` are applied per chunk."""
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        if cols:
            for i in in_filters_tuple_list:
                query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        else:
            for i in in_filters_tuple_list:
                query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        return res_list
|
||||
35
api/db/services/dialog_service.py
Normal file
35
api/db/services/dialog_service.py
Normal file
@ -0,0 +1,35 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import Dialog, Conversation, DialogKb
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
model = Dialog
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
model = Conversation
|
||||
|
||||
|
||||
class DialogKbService(CommonService):
|
||||
model = DialogKb
|
||||
96
api/db/services/document_service.py
Normal file
96
api/db/services/document_service.py
Normal file
@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from peewee import Expression
|
||||
|
||||
from web_server.db import TenantPermission, FileType
|
||||
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
||||
from web_server.db.db_models import Document
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
cls.model.kb_id == kb_id,
|
||||
cls.model.name.like(f"%%{keywords}%%"))
|
||||
else:
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
docs = docs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
if not cls.save(**doc):
|
||||
raise RuntimeError("Database error (Document)!")
|
||||
e, doc = cls.get_by_id(doc["id"])
|
||||
if not e:
|
||||
raise RuntimeError("Database error (Document retrieval)!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not KnowledgebaseService.update_by_id(
|
||||
kb.id, {"doc_num": kb.doc_num + 1}):
|
||||
raise RuntimeError("Database error (Knowledgebase)!")
|
||||
return doc
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= tm,
|
||||
(Expression(cls.model.create_time, "%%", comm) == mod))\
|
||||
.order_by(cls.model.update_time.asc())\
|
||||
.paginate(1, items_per_page)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duation=cls.model.process_duation+duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:return
|
||||
return docs[0]["tenant_id"]
|
||||
70
api/db/services/kb_service.py
Normal file
70
api/db/services/kb_service.py
Normal file
@ -0,0 +1,70 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db import TenantPermission
|
||||
from web_server.db.db_models import DB, UserTenant, Tenant
|
||||
from web_server.db.db_models import Knowledgebase
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
model = Knowledgebase
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page, orderby, desc):
|
||||
kbs = cls.model.select().where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
kbs = kbs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(kbs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, kb_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
Tenant.embd_id,
|
||||
cls.model.avatar,
|
||||
cls.model.name,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if not kbs:
|
||||
return
|
||||
d = kbs[0].to_dict()
|
||||
d["embd_id"] = kbs[0].tenant.embd_id
|
||||
return d
|
||||
31
api/db/services/knowledgebase_service.py
Normal file
31
api/db/services/knowledgebase_service.py
Normal file
@ -0,0 +1,31 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import Knowledgebase, Document
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
model = Knowledgebase
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
76
api/db/services/llm_service.py
Normal file
76
api/db/services/llm_service.py
Normal file
@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
model = LLMFactories
|
||||
|
||||
|
||||
class LLMService(CommonService):
|
||||
model = LLM
|
||||
|
||||
|
||||
class TenantLLMService(CommonService):
|
||||
model = TenantLLM
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_type):
|
||||
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
||||
if objs and len(objs)>0 and objs[0].llm_name:
|
||||
return objs[0]
|
||||
|
||||
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
||||
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
||||
(cls.model.tenant_id == tenant_id),
|
||||
(cls.model.model_type == model_type),
|
||||
(LLM.status == StatusEnum.VALID)
|
||||
)
|
||||
|
||||
if not objs:return
|
||||
return objs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory==LLMFactories.name)).where(cls.model.tenant_id==tenant_id).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type):
|
||||
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
||||
if not model_config:
|
||||
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
||||
else:
|
||||
model_config = model_config[0].to_dict()
|
||||
if llm_type == LLMType.EMBEDDING:
|
||||
if model_config["llm_factory"] not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
||||
if llm_type == LLMType.IMAGE2TEXT:
|
||||
if model_config["llm_factory"] not in CvModel: return
|
||||
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
|
||||
105
api/db/services/user_service.py
Normal file
105
api/db/services/user_service.py
Normal file
@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db import UserTenantRole
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import User, Tenant
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class UserService(CommonService):
|
||||
model = User
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_id(cls, user_id):
|
||||
try:
|
||||
user = cls.model.select().where(cls.model.id == user_id).get()
|
||||
return user
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user(cls, email, password):
|
||||
user = cls.model.select().where((cls.model.email == email),
|
||||
(cls.model.status == StatusEnum.VALID.value)).first()
|
||||
if user and check_password_hash(str(user.password), password):
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
if "password" in kwargs:
|
||||
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_user(cls, user_ids, update_user_dict):
|
||||
with DB.atomic():
|
||||
cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user(cls, user_id, user_dict):
|
||||
date_time = get_format_time()
|
||||
with DB.atomic():
|
||||
if user_dict:
|
||||
user_dict["update_time"] = date_time
|
||||
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
model = Tenant
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_user_id(cls, user_id):
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
||||
return list(cls.model.select(*fields)\
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_joined_tenants_by_user_id(cls, user_id):
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
||||
return list(cls.model.select(*fields)\
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
model = UserTenant
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
10
api/errors/__init__.py
Normal file
10
api/errors/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
from .general_error import *
|
||||
|
||||
|
||||
class RagFlowError(Exception):
|
||||
message = 'Unknown Rag Flow Error'
|
||||
|
||||
def __init__(self, message=None, *args, **kwargs):
|
||||
message = str(message) if message is not None else self.message
|
||||
message = message.format(*args, **kwargs)
|
||||
super().__init__(message)
|
||||
13
api/errors/error_services.py
Normal file
13
api/errors/error_services.py
Normal file
@ -0,0 +1,13 @@
|
||||
from web_server.errors import RagFlowError
|
||||
|
||||
__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
|
||||
'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']
|
||||
|
||||
|
||||
class ServicesError(RagFlowError):
|
||||
message = 'Unknown services error'
|
||||
|
||||
|
||||
class ServiceNotSupported(ServicesError):
|
||||
message = 'The service {service_name} is not supported'
|
||||
|
||||
21
api/errors/general_error.py
Normal file
21
api/errors/general_error.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
class ParameterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PassError(Exception):
|
||||
pass
|
||||
BIN
api/flask_session/2029240f6d1128be89ddc32729463129
Normal file
BIN
api/flask_session/2029240f6d1128be89ddc32729463129
Normal file
Binary file not shown.
57
api/hook/__init__.py
Normal file
57
api/hook/__init__.py
Normal file
@ -0,0 +1,57 @@
|
||||
import importlib
|
||||
|
||||
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
|
||||
SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
|
||||
from web_server.settings import HOOK_MODULE, stat_logger,RetCode
|
||||
|
||||
|
||||
class HookManager:
|
||||
SITE_SIGNATURE = []
|
||||
SITE_AUTHENTICATION = []
|
||||
CLIENT_AUTHENTICATION = []
|
||||
PERMISSION_CHECK = []
|
||||
|
||||
@staticmethod
|
||||
def init():
|
||||
if HOOK_MODULE is not None:
|
||||
for modules in HOOK_MODULE.values():
|
||||
for module in modules.split(";"):
|
||||
try:
|
||||
importlib.import_module(module)
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
|
||||
@staticmethod
|
||||
def register_site_signature_hook(func):
|
||||
HookManager.SITE_SIGNATURE.append(func)
|
||||
|
||||
@staticmethod
|
||||
def register_site_authentication_hook(func):
|
||||
HookManager.SITE_AUTHENTICATION.append(func)
|
||||
|
||||
@staticmethod
|
||||
def register_client_authentication_hook(func):
|
||||
HookManager.CLIENT_AUTHENTICATION.append(func)
|
||||
|
||||
@staticmethod
|
||||
def register_permission_check_hook(func):
|
||||
HookManager.PERMISSION_CHECK.append(func)
|
||||
|
||||
@staticmethod
|
||||
def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
|
||||
if HookManager.CLIENT_AUTHENTICATION:
|
||||
return HookManager.CLIENT_AUTHENTICATION[0](parm)
|
||||
return ClientAuthenticationReturn()
|
||||
|
||||
@staticmethod
|
||||
def site_signature(parm: SignatureParameters) -> SignatureReturn:
|
||||
if HookManager.SITE_SIGNATURE:
|
||||
return HookManager.SITE_SIGNATURE[0](parm)
|
||||
return SignatureReturn()
|
||||
|
||||
@staticmethod
|
||||
def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
|
||||
if HookManager.SITE_AUTHENTICATION:
|
||||
return HookManager.SITE_AUTHENTICATION[0](parm)
|
||||
return AuthenticationReturn()
|
||||
|
||||
29
api/hook/api/client_authentication.py
Normal file
29
api/hook/api/client_authentication.py
Normal file
@ -0,0 +1,29 @@
|
||||
import requests
|
||||
|
||||
from web_server.db.service_registry import ServiceRegistry
|
||||
from web_server.settings import RegistryServiceName
|
||||
from web_server.hook import HookManager
|
||||
from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
|
||||
from web_server.settings import HOOK_SERVER_NAME
|
||||
|
||||
|
||||
@HookManager.register_client_authentication_hook
|
||||
def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
|
||||
service_list = ServiceRegistry.load_service(
|
||||
server_name=HOOK_SERVER_NAME,
|
||||
service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
|
||||
)
|
||||
if not service_list:
|
||||
raise Exception(f"client authentication error: no found server"
|
||||
f" {HOOK_SERVER_NAME} service client_authentication")
|
||||
service = service_list[0]
|
||||
response = getattr(requests, service.f_method.lower(), None)(
|
||||
url=service.f_url,
|
||||
json=parm.to_dict()
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"client authentication error: request authentication url failed, status code {response.status_code}")
|
||||
elif response.json().get("code") != 0:
|
||||
return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
|
||||
return ClientAuthenticationReturn()
|
||||
25
api/hook/api/permission.py
Normal file
25
api/hook/api/permission.py
Normal file
@ -0,0 +1,25 @@
|
||||
import requests
|
||||
|
||||
from web_server.db.service_registry import ServiceRegistry
|
||||
from web_server.settings import RegistryServiceName
|
||||
from web_server.hook import HookManager
|
||||
from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn
|
||||
from web_server.settings import HOOK_SERVER_NAME
|
||||
|
||||
|
||||
@HookManager.register_permission_check_hook
|
||||
def permission(parm: PermissionCheckParameters) -> PermissionReturn:
|
||||
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
|
||||
if not service_list:
|
||||
raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission")
|
||||
service = service_list[0]
|
||||
response = getattr(requests, service.f_method.lower(), None)(
|
||||
url=service.f_url,
|
||||
json=parm.to_dict()
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"permission check error: request permission url failed, status code {response.status_code}")
|
||||
elif response.json().get("code") != 0:
|
||||
return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
|
||||
return PermissionReturn()
|
||||
49
api/hook/api/site_authentication.py
Normal file
49
api/hook/api/site_authentication.py
Normal file
@ -0,0 +1,49 @@
|
||||
import requests
|
||||
|
||||
from web_server.db.service_registry import ServiceRegistry
|
||||
from web_server.settings import RegistryServiceName
|
||||
from web_server.hook import HookManager
|
||||
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
|
||||
SignatureReturn
|
||||
from web_server.settings import HOOK_SERVER_NAME, PARTY_ID
|
||||
|
||||
|
||||
@HookManager.register_site_signature_hook
|
||||
def signature(parm: SignatureParameters) -> SignatureReturn:
|
||||
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
|
||||
if not service_list:
|
||||
raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature")
|
||||
service = service_list[0]
|
||||
response = getattr(requests, service.f_method.lower(), None)(
|
||||
url=service.f_url,
|
||||
json=parm.to_dict()
|
||||
)
|
||||
if response.status_code == 200:
|
||||
if response.json().get("code") == 0:
|
||||
return SignatureReturn(site_signature=response.json().get("data"))
|
||||
else:
|
||||
raise Exception(f"signature error: request signature url failed, result: {response.json()}")
|
||||
else:
|
||||
raise Exception(f"signature error: request signature url failed, status code {response.status_code}")
|
||||
|
||||
|
||||
@HookManager.register_site_authentication_hook
|
||||
def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
|
||||
if not parm.src_party_id or str(parm.src_party_id) == "0":
|
||||
parm.src_party_id = PARTY_ID
|
||||
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
|
||||
service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
|
||||
if not service_list:
|
||||
raise Exception(
|
||||
f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication")
|
||||
service = service_list[0]
|
||||
response = getattr(requests, service.f_method.lower(), None)(
|
||||
url=service.f_url,
|
||||
json=parm.to_dict()
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"site authentication error: request site_authentication url failed, status code {response.status_code}")
|
||||
elif response.json().get("code") != 0:
|
||||
return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
|
||||
return AuthenticationReturn()
|
||||
56
api/hook/common/parameters.py
Normal file
56
api/hook/common/parameters.py
Normal file
@ -0,0 +1,56 @@
|
||||
from web_server.settings import RetCode
|
||||
|
||||
|
||||
class ParametersBase:
|
||||
def to_dict(self):
|
||||
d = {}
|
||||
for k, v in self.__dict__.items():
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
class ClientAuthenticationParameters(ParametersBase):
|
||||
def __init__(self, full_path, headers, form, data, json):
|
||||
self.full_path = full_path
|
||||
self.headers = headers
|
||||
self.form = form
|
||||
self.data = data
|
||||
self.json = json
|
||||
|
||||
|
||||
class ClientAuthenticationReturn(ParametersBase):
|
||||
def __init__(self, code=RetCode.SUCCESS, message="success"):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class SignatureParameters(ParametersBase):
|
||||
def __init__(self, party_id, body):
|
||||
self.party_id = party_id
|
||||
self.body = body
|
||||
|
||||
|
||||
class SignatureReturn(ParametersBase):
|
||||
def __init__(self, code=RetCode.SUCCESS, site_signature=None):
|
||||
self.code = code
|
||||
self.site_signature = site_signature
|
||||
|
||||
|
||||
class AuthenticationParameters(ParametersBase):
|
||||
def __init__(self, site_signature, body):
|
||||
self.site_signature = site_signature
|
||||
self.body = body
|
||||
|
||||
|
||||
class AuthenticationReturn(ParametersBase):
|
||||
def __init__(self, code=RetCode.SUCCESS, message="success"):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class PermissionReturn(ParametersBase):
|
||||
def __init__(self, code=RetCode.SUCCESS, message="success"):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
80
api/ragflow_server.py
Normal file
80
api/ragflow_server.py
Normal file
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# init env. must be the first import
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from werkzeug.serving import run_simple
|
||||
|
||||
from web_server.apps import app
|
||||
from web_server.db.runtime_config import RuntimeConfig
|
||||
from web_server.hook import HookManager
|
||||
from web_server.settings import (
|
||||
HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
|
||||
)
|
||||
from web_server import utils
|
||||
|
||||
from web_server.db.db_models import init_database_tables as init_web_db
|
||||
from web_server.db.init_data import init_web_data
|
||||
from web_server.versions import get_versions
|
||||
|
||||
if __name__ == '__main__':
|
||||
stat_logger.info(
|
||||
f'project base: {utils.file_utils.get_project_base_directory()}'
|
||||
)
|
||||
|
||||
# init db
|
||||
init_web_db()
|
||||
init_web_data()
|
||||
# init runtime config
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
|
||||
parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
|
||||
args = parser.parse_args()
|
||||
if args.version:
|
||||
print(get_versions())
|
||||
sys.exit(0)
|
||||
|
||||
RuntimeConfig.DEBUG = args.debug
|
||||
if RuntimeConfig.DEBUG:
|
||||
stat_logger.info("run on debug mode")
|
||||
|
||||
RuntimeConfig.init_env()
|
||||
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
|
||||
|
||||
HookManager.init()
|
||||
|
||||
peewee_logger = logging.getLogger('peewee')
|
||||
peewee_logger.propagate = False
|
||||
# rag_arch.common.log.ROpenHandler
|
||||
peewee_logger.addHandler(database_logger.handlers[0])
|
||||
peewee_logger.setLevel(database_logger.level)
|
||||
|
||||
# start http server
|
||||
try:
|
||||
stat_logger.info("RAG Flow http server start...")
|
||||
werkzeug_logger = logging.getLogger("werkzeug")
|
||||
for h in access_logger.handlers:
|
||||
werkzeug_logger.addHandler(h)
|
||||
run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGKILL)
|
||||
156
api/settings.py
Normal file
156
api/settings.py
Normal file
@ -0,0 +1,156 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
from enum import IntEnum, Enum
|
||||
|
||||
from web_server.utils import get_base_config,decrypt_database_config
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
from web_server.utils.log_utils import LoggerFactory, getLogger
|
||||
|
||||
|
||||
# Server
|
||||
API_VERSION = "v1"
|
||||
RAG_FLOW_SERVICE_NAME = "ragflow"
|
||||
SERVER_MODULE = "rag_flow_server.py"
|
||||
TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
|
||||
RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
|
||||
SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||
|
||||
ERROR_REPORT = True
|
||||
ERROR_REPORT_WITH_PATH = False
|
||||
|
||||
MAX_TIMESTAMP_INTERVAL = 60
|
||||
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
|
||||
|
||||
REQUEST_TRY_TIMES = 3
|
||||
REQUEST_WAIT_SEC = 2
|
||||
REQUEST_MAX_WAIT_SEC = 300
|
||||
|
||||
USE_REGISTRY = get_base_config("use_registry")
|
||||
|
||||
LLM = get_base_config("llm", {})
|
||||
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
|
||||
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
|
||||
ASR_MDL = LLM.get("asr_model", "whisper-1")
|
||||
PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report")
|
||||
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
|
||||
|
||||
# distribution
|
||||
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
||||
RAG_FLOW_UPDATE_CHECK = False
|
||||
|
||||
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||
|
||||
SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
|
||||
TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)
|
||||
|
||||
NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
|
||||
NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT
|
||||
|
||||
RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)
|
||||
|
||||
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
|
||||
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
|
||||
|
||||
DATABASE = decrypt_database_config()
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server"))
|
||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||
LoggerFactory.LEVEL = 10
|
||||
|
||||
stat_logger = getLogger("stat")
|
||||
access_logger = getLogger("access")
|
||||
database_logger = getLogger("database")
|
||||
|
||||
# Switch
|
||||
# upload
|
||||
UPLOAD_DATA_FROM_CLIENT = True
|
||||
|
||||
# authentication
|
||||
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
||||
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
|
||||
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
||||
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
||||
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
|
||||
|
||||
# site
|
||||
SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)
|
||||
|
||||
# permission
|
||||
PERMISSION_CONF = get_base_config("permission", {})
|
||||
PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
|
||||
COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
|
||||
DATASET_PERMISSION = PERMISSION_CONF.get("dataset")
|
||||
|
||||
HOOK_MODULE = get_base_config("hook_module")
|
||||
HOOK_SERVER_NAME = get_base_config("hook_server_name")
|
||||
|
||||
ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
|
||||
# authentication
|
||||
USE_AUTHENTICATION = False
|
||||
USE_DATA_AUTHENTICATION = False
|
||||
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
|
||||
USE_DEFAULT_TIMEOUT = False
|
||||
AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
|
||||
PRIVILEGE_COMMAND_WHITELIST = []
|
||||
CHECK_NODES_IDENTITY = False
|
||||
|
||||
class CustomEnum(Enum):
|
||||
@classmethod
|
||||
def valid(cls, value):
|
||||
try:
|
||||
cls(value)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [member.value for member in cls.__members__.values()]
|
||||
|
||||
@classmethod
|
||||
def names(cls):
|
||||
return [member.name for member in cls.__members__.values()]
|
||||
|
||||
|
||||
class PythonDependenceName(CustomEnum):
|
||||
Rag_Source_Code = "python"
|
||||
Python_Env = "miniconda"
|
||||
|
||||
|
||||
class ModelStorage(CustomEnum):
|
||||
REDIS = "redis"
|
||||
MYSQL = "mysql"
|
||||
|
||||
|
||||
class RetCode(IntEnum, CustomEnum):
|
||||
SUCCESS = 0
|
||||
NOT_EFFECTIVE = 10
|
||||
EXCEPTION_ERROR = 100
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
OPERATING_ERROR = 103
|
||||
CONNECTION_ERROR = 105
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
SERVER_ERROR = 500
|
||||
321
api/utils/__init__.py
Normal file
321
api/utils/__init__.py
Normal file
@ -0,0 +1,321 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
import requests
|
||||
from enum import Enum, IntEnum
|
||||
import importlib
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from . import file_utils
|
||||
|
||||
SERVICE_CONF = "service_conf.yaml"
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
if key is not None and key in local_config:
|
||||
return local_config[key]
|
||||
|
||||
config_path = conf_realpath(conf_name)
|
||||
config = file_utils.load_yaml_conf(config_path)
|
||||
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(f'Invalid config file: "{config_path}".')
|
||||
|
||||
config.update(local_config)
|
||||
return config.get(key, default) if key is not None else config
|
||||
|
||||
|
||||
use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)
|
||||
|
||||
|
||||
class CoordinationCommunicationProtocol(object):
|
||||
HTTP = "http"
|
||||
GRPC = "grpc"
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__, "data": data, "module": module}
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def rag_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
|
||||
|
||||
|
||||
def current_timestamp():
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
|
||||
if not timestamp:
|
||||
timestamp = time.time()
|
||||
timestamp = int(timestamp) / 1000
|
||||
time_array = time.localtime(timestamp)
|
||||
str_date = time.strftime(format_string, time_array)
|
||||
return str_date
|
||||
|
||||
|
||||
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
||||
time_array = time.strptime(time_str, format_string)
|
||||
time_stamp = int(time.mktime(time_array) * 1000)
|
||||
return time_stamp
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
if os.name != "nt":
|
||||
import fcntl
|
||||
import struct
|
||||
|
||||
def get_interface_ip(ifname):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
return socket.inet_ntoa(
|
||||
fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
|
||||
|
||||
ip = socket.gethostbyname(socket.getfqdn())
|
||||
if ip.startswith("127.") and os.name != "nt":
|
||||
interfaces = [
|
||||
"bond1",
|
||||
"eth0",
|
||||
"eth1",
|
||||
"eth2",
|
||||
"wlan0",
|
||||
"wlan1",
|
||||
"wifi0",
|
||||
"ath0",
|
||||
"ath1",
|
||||
"ppp0",
|
||||
]
|
||||
for ifname in interfaces:
|
||||
try:
|
||||
ip = get_interface_ip(ifname)
|
||||
break
|
||||
except IOError as e:
|
||||
pass
|
||||
return ip or ''
|
||||
|
||||
def from_dict_hook(in_dict: dict):
|
||||
if "type" in in_dict and "data" in in_dict:
|
||||
if in_dict["module"] is None:
|
||||
return in_dict["data"]
|
||||
else:
|
||||
return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
|
||||
else:
|
||||
return in_dict
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(database=None, passwd_key="passwd", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
|
||||
return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
|
||||
|
||||
|
||||
def get_format_time() -> datetime.datetime:
|
||||
return datetime_format(datetime.datetime.now())
|
||||
|
||||
|
||||
def str2date(date_time: str):
|
||||
return datetime.datetime.strptime(date_time, '%Y-%m-%d')
|
||||
|
||||
|
||||
def elapsed2time(elapsed):
|
||||
seconds = elapsed / 1000
|
||||
minuter, second = divmod(seconds, 60)
|
||||
hour, minuter = divmod(minuter, 60)
|
||||
return '%02d:%02d:%02d' % (hour, minuter, second)
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def download_img(url):
|
||||
if not url: return ""
|
||||
response = requests.get(url)
|
||||
return "data:" + \
|
||||
response.headers.get('Content-Type', 'image/jpg') + ";" + \
|
||||
"base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
212
api/utils/api_utils.py
Normal file
212
api/utils/api_utils.py
Normal file
@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from flask import (
|
||||
Response, jsonify, send_file,make_response,
|
||||
request as flask_request,
|
||||
)
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from web_server.utils import json_dumps
|
||||
from web_server.versions import get_rag_version
|
||||
from web_server.settings import RetCode
|
||||
from web_server.settings import (
|
||||
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
|
||||
stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
|
||||
)
|
||||
import requests
|
||||
import functools
|
||||
from web_server.utils import CustomJSONEncoder
|
||||
from uuid import uuid1
|
||||
from base64 import b64encode
|
||||
from hmac import HMAC
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
def request(**kwargs):
|
||||
sess = requests.Session()
|
||||
stream = kwargs.pop('stream', sess.stream)
|
||||
timeout = kwargs.pop('timeout', None)
|
||||
kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
|
||||
prepped = requests.Request(**kwargs).prepare()
|
||||
|
||||
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
|
||||
timestamp = str(round(time() * 1000))
|
||||
nonce = str(uuid1())
|
||||
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
|
||||
timestamp.encode('ascii'),
|
||||
nonce.encode('ascii'),
|
||||
HTTP_APP_KEY.encode('ascii'),
|
||||
prepped.path_url.encode('ascii'),
|
||||
prepped.body if kwargs.get('json') else b'',
|
||||
urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
|
||||
if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
|
||||
]), 'sha1').digest()).decode('ascii')
|
||||
|
||||
prepped.headers.update({
|
||||
'TIMESTAMP': timestamp,
|
||||
'NONCE': nonce,
|
||||
'APP-KEY': HTTP_APP_KEY,
|
||||
'SIGNATURE': signature,
|
||||
})
|
||||
|
||||
return sess.send(prepped, stream=stream, timeout=timeout)
|
||||
|
||||
|
||||
rag_version = get_rag_version() or ''
|
||||
|
||||
|
||||
def get_exponential_backoff_interval(retries, full_jitter=False):
|
||||
"""Calculate the exponential backoff wait time."""
|
||||
# Will be zero if factor equals 0
|
||||
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
||||
# Full jitter according to
|
||||
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||
if full_jitter:
|
||||
countdown = random.randrange(countdown + 1)
|
||||
# Adjust according to maximum wait time and account for negative values.
|
||||
return max(0, countdown)
|
||||
|
||||
|
||||
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
|
||||
import re
|
||||
result_dict = {
|
||||
"retcode": retcode,
|
||||
"retmsg":retmsg,
|
||||
# "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
|
||||
"data": data,
|
||||
"jobId": job_id,
|
||||
"meta": meta,
|
||||
}
|
||||
|
||||
response = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "retcode":
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
|
||||
def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
|
||||
import re
|
||||
result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
|
||||
response = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "retcode":
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
|
||||
def server_error_response(e):
|
||||
stat_logger.exception(e)
|
||||
try:
|
||||
if e.code==401:
|
||||
return get_json_result(retcode=401, retmsg=repr(e))
|
||||
except:
|
||||
pass
|
||||
if len(e.args) > 1:
|
||||
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
|
||||
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
|
||||
|
||||
|
||||
def error_response(response_code, retmsg=None):
|
||||
if retmsg is None:
|
||||
retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
|
||||
|
||||
return Response(json.dumps({
|
||||
'retmsg': retmsg,
|
||||
'retcode': response_code,
|
||||
}), status=response_code, mimetype='application/json')
|
||||
|
||||
|
||||
def validate_request(*args, **kwargs):
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*_args, **_kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
|
||||
return func(*_args, **_kwargs)
|
||||
return decorated_function
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_localhost(ip):
|
||||
return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
|
||||
|
||||
|
||||
def send_file_in_mem(data, filename):
|
||||
if not isinstance(data, (str, bytes)):
|
||||
data = json_dumps(data)
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
f = BytesIO()
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
return send_file(f, as_attachment=True, attachment_filename=filename)
|
||||
|
||||
|
||||
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
|
||||
response = {"retcode": retcode, "retmsg": retmsg, "data": data}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
|
||||
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "retcode":
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = make_response(jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Method"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Expose-Headers"] = "Authorization"
|
||||
return response
|
||||
153
api/utils/file_utils.py
Normal file
153
api/utils/file_utils.py
Normal file
@ -0,0 +1,153 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from cachetools import LRUCache, cached
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from web_server.db import FileType
|
||||
|
||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
RAG_BASE = os.getenv("RAG_BASE")
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
if PROJECT_BASE is None:
|
||||
PROJECT_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
)
|
||||
)
|
||||
|
||||
if args:
|
||||
return os.path.join(PROJECT_BASE, *args)
|
||||
return PROJECT_BASE
|
||||
|
||||
|
||||
def get_rag_directory(*args):
|
||||
global RAG_BASE
|
||||
if RAG_BASE is None:
|
||||
RAG_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
)
|
||||
)
|
||||
if args:
|
||||
return os.path.join(RAG_BASE, *args)
|
||||
return RAG_BASE
|
||||
|
||||
|
||||
def get_rag_python_directory(*args):
|
||||
return get_rag_directory("python", *args)
|
||||
|
||||
|
||||
|
||||
@cached(cache=LRUCache(maxsize=10))
|
||||
def load_json_conf(conf_path):
|
||||
if os.path.isabs(conf_path):
|
||||
json_conf_path = conf_path
|
||||
else:
|
||||
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||
try:
|
||||
with open(json_conf_path) as f:
|
||||
return json.load(f)
|
||||
except BaseException:
|
||||
raise EnvironmentError(
|
||||
"loading json file config from '{}' failed!".format(json_conf_path)
|
||||
)
|
||||
|
||||
|
||||
def dump_json_conf(config_data, conf_path):
|
||||
if os.path.isabs(conf_path):
|
||||
json_conf_path = conf_path
|
||||
else:
|
||||
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||
try:
|
||||
with open(json_conf_path, "w") as f:
|
||||
json.dump(config_data, f, indent=4)
|
||||
except BaseException:
|
||||
raise EnvironmentError(
|
||||
"loading json file config from '{}' failed!".format(json_conf_path)
|
||||
)
|
||||
|
||||
|
||||
def load_json_conf_real_time(conf_path):
|
||||
if os.path.isabs(conf_path):
|
||||
json_conf_path = conf_path
|
||||
else:
|
||||
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||
try:
|
||||
with open(json_conf_path) as f:
|
||||
return json.load(f)
|
||||
except BaseException:
|
||||
raise EnvironmentError(
|
||||
"loading json file config from '{}' failed!".format(json_conf_path)
|
||||
)
|
||||
|
||||
|
||||
def load_yaml_conf(conf_path):
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||
try:
|
||||
with open(conf_path) as f:
|
||||
yaml = YAML(typ='safe', pure=True)
|
||||
return yaml.load(f)
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"loading yaml file config from {} failed:".format(conf_path), e
|
||||
)
|
||||
|
||||
|
||||
def rewrite_yaml_conf(conf_path, config):
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(get_project_base_directory(), conf_path)
|
||||
try:
|
||||
with open(conf_path, "w") as f:
|
||||
yaml = YAML(typ="safe")
|
||||
yaml.dump(config, f)
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"rewrite yaml file config {} failed:".format(conf_path), e
|
||||
)
|
||||
|
||||
|
||||
def rewrite_json_file(filepath, json_data):
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(json_data, f, indent=4, separators=(",", ": "))
|
||||
f.close()
|
||||
|
||||
|
||||
def filename_type(filename):
|
||||
filename = filename.lower()
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
return FileType.PDF.value
|
||||
|
||||
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
||||
return FileType.DOC.value
|
||||
|
||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
||||
return FileType.AURAL.value
|
||||
|
||||
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
|
||||
return FileType.VISUAL
|
||||
294
api/utils/log_utils.py
Normal file
294
api/utils/log_utils.py
Normal file
@ -0,0 +1,294 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import typing
|
||||
import traceback
|
||||
import logging
|
||||
import inspect
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from threading import RLock
|
||||
|
||||
from web_server.utils import file_utils
|
||||
|
||||
class LoggerFactory(object):
|
||||
TYPE = "FILE"
|
||||
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
|
||||
LEVEL = logging.DEBUG
|
||||
logger_dict = {}
|
||||
global_handler_dict = {}
|
||||
|
||||
LOG_DIR = None
|
||||
PARENT_LOG_DIR = None
|
||||
log_share = True
|
||||
|
||||
append_to_parent_log = None
|
||||
|
||||
lock = RLock()
|
||||
# CRITICAL = 50
|
||||
# FATAL = CRITICAL
|
||||
# ERROR = 40
|
||||
# WARNING = 30
|
||||
# WARN = WARNING
|
||||
# INFO = 20
|
||||
# DEBUG = 10
|
||||
# NOTSET = 0
|
||||
levels = (10, 20, 30, 40)
|
||||
schedule_logger_dict = {}
|
||||
|
||||
@staticmethod
|
||||
def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
|
||||
if parent_log_dir:
|
||||
LoggerFactory.PARENT_LOG_DIR = parent_log_dir
|
||||
if append_to_parent_log:
|
||||
LoggerFactory.append_to_parent_log = append_to_parent_log
|
||||
with LoggerFactory.lock:
|
||||
if not directory:
|
||||
directory = file_utils.get_project_base_directory("logs")
|
||||
if not LoggerFactory.LOG_DIR or force:
|
||||
LoggerFactory.LOG_DIR = directory
|
||||
if LoggerFactory.log_share:
|
||||
oldmask = os.umask(000)
|
||||
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
|
||||
os.umask(oldmask)
|
||||
else:
|
||||
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
|
||||
for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
|
||||
for className, (logger, handler) in LoggerFactory.logger_dict.items():
|
||||
logger.removeHandler(ghandler)
|
||||
ghandler.close()
|
||||
LoggerFactory.global_handler_dict = {}
|
||||
for className, (logger, handler) in LoggerFactory.logger_dict.items():
|
||||
logger.removeHandler(handler)
|
||||
_handler = None
|
||||
if handler:
|
||||
handler.close()
|
||||
if className != "default":
|
||||
_handler = LoggerFactory.get_handler(className)
|
||||
logger.addHandler(_handler)
|
||||
LoggerFactory.assemble_global_handler(logger)
|
||||
LoggerFactory.logger_dict[className] = logger, _handler
|
||||
|
||||
@staticmethod
def new_logger(name):
    """Return a fresh, non-propagating logger at the factory's level."""
    created = logging.getLogger(name)
    created.propagate = False
    created.setLevel(LoggerFactory.LEVEL)
    return created
|
||||
|
||||
@staticmethod
def get_logger(class_name=None):
    """Return the cached logger for *class_name*, creating it on first use."""
    with LoggerFactory.lock:
        cached = LoggerFactory.logger_dict.get(class_name)
        if cached is not None:
            logger, handler = cached
            if not logger:
                # Entry existed but logger was never built: initialize now.
                logger, handler = LoggerFactory.init_logger(class_name)
        else:
            logger, handler = LoggerFactory.init_logger(class_name)
        return logger
|
||||
|
||||
@staticmethod
def get_global_handler(logger_name, level=None, log_dir=None):
    """Return (creating at most once) the shared handler for a level/dir pair."""
    if not LoggerFactory.LOG_DIR:
        # No log directory configured yet: fall back to stderr.
        return logging.StreamHandler()
    target_dir = log_dir if log_dir else LoggerFactory.LOG_DIR
    logger_name_key = logger_name + "_" + target_dir
    # if loggerName not in LoggerFactory.globalHandlerDict:
    if logger_name_key not in LoggerFactory.global_handler_dict:
        with LoggerFactory.lock:
            # Double-checked: another thread may have created it meanwhile.
            if logger_name_key not in LoggerFactory.global_handler_dict:
                new_handler = LoggerFactory.get_handler(logger_name, level, log_dir)
                LoggerFactory.global_handler_dict[logger_name_key] = new_handler
    return LoggerFactory.global_handler_dict[logger_name_key]
|
||||
|
||||
@staticmethod
def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
    """Build a daily-rotating file handler (14 backups) for a log file.

    ``job_id`` is accepted but unused by this implementation.
    """
    if not log_type:
        if not LoggerFactory.LOG_DIR or not class_name:
            # Nothing to write to disk yet: log to stderr instead.
            return logging.StreamHandler()
        # return Diy_StreamHandler()

        if not log_dir:
            log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
        else:
            log_file = os.path.join(log_dir, "{}.log".format(class_name))
    else:
        # Job-style logs: one file per log_type, with a separate *_error
        # file for any level other than the factory default.
        # NOTE(review): assumes log_dir is always provided when log_type is
        # set; os.path.join(None, ...) would raise otherwise — confirm callers.
        log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
            log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))

    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    if LoggerFactory.log_share:
        # World-writable variant (opens files with umask 0) for shared dirs.
        handler = ROpenHandler(log_file,
                               when='D',
                               interval=1,
                               backupCount=14,
                               delay=True)
    else:
        handler = TimedRotatingFileHandler(log_file,
                                           when='D',
                                           interval=1,
                                           backupCount=14,
                                           delay=True)
    if level:
        handler.level = level

    return handler
|
||||
|
||||
@staticmethod
def init_logger(class_name):
    """Create, register and fully wire the logger/handler pair for *class_name*."""
    with LoggerFactory.lock:
        logger_obj = LoggerFactory.new_logger(class_name)
        file_handler = LoggerFactory.get_handler(class_name) if class_name else None
        if file_handler is not None:
            logger_obj.addHandler(file_handler)
        # Anonymous (falsy) names are registered under the "default" key.
        registry_key = class_name if class_name else "default"
        LoggerFactory.logger_dict[registry_key] = logger_obj, file_handler
        LoggerFactory.assemble_global_handler(logger_obj)
        return logger_obj, file_handler
|
||||
|
||||
@staticmethod
def assemble_global_handler(logger):
    """Attach the shared per-level handlers (and parent-dir handlers) to *logger*.

    For every configured level at or above the factory minimum, the shared
    handler for that level is added; when append_to_parent_log is enabled
    the same is done against PARENT_LOG_DIR.
    """
    if LoggerFactory.LOG_DIR:
        for level in LoggerFactory.levels:
            if level >= LoggerFactory.LEVEL:
                # logging.getLevelName is the public equivalent of the
                # private logging._levelToName lookup used previously.
                level_logger_name = logging.getLevelName(level)
                logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
    if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
        for level in LoggerFactory.levels:
            if level >= LoggerFactory.LEVEL:
                level_logger_name = logging.getLevelName(level)
                logger.addHandler(
                    LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR))
|
||||
|
||||
|
||||
def setDirectory(directory=None):
    """Module-level convenience wrapper around LoggerFactory.set_directory."""
    LoggerFactory.set_directory(directory=directory)
|
||||
|
||||
|
||||
def setLevel(level):
    """Set the factory-wide minimum logging level for new loggers."""
    setattr(LoggerFactory, "LEVEL", level)
|
||||
|
||||
|
||||
def getLogger(className=None, useLevelFile=False):
    """Return the factory logger named *className* (default: "stat").

    ``useLevelFile`` is accepted for backward compatibility but unused.
    """
    if className is None:
        # The previous implementation inspected the caller's stack frame
        # (inspect.stack()[1]) and then discarded the result; the effective
        # default was always "stat", so the dead, costly inspection is removed.
        className = 'stat'
    return LoggerFactory.get_logger(className)
|
||||
|
||||
|
||||
def exception_to_trace_string(ex):
    """Render *ex* and its traceback as the standard formatted string."""
    tb_exc = traceback.TracebackException.from_exception(ex)
    return "".join(tb_exc.format())
|
||||
|
||||
|
||||
class ROpenHandler(TimedRotatingFileHandler):
    """TimedRotatingFileHandler that opens its files with umask 0.

    Clearing the umask while opening makes the log files world-writable so
    several processes/users can share one log directory (see log_share).
    """

    def _open(self):
        prevumask = os.umask(0)
        try:
            return TimedRotatingFileHandler._open(self)
        finally:
            # Fix: restore the umask even when _open raises (the original
            # left the process umask at 0 on failure).
            os.umask(prevumask)
|
||||
|
||||
|
||||
def sql_logger(job_id='', log_type='sql'):
    """Return the cached SQL logger for *job_id*, building it if absent."""
    cache_key = job_id + log_type
    cached = LoggerFactory.schedule_logger_dict.get(cache_key)
    if cached is not None:
        return cached
    return get_job_logger(job_id=job_id, log_type=log_type)
|
||||
|
||||
|
||||
def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    """Format a "<msg> ready" status line with job/task context."""
    head, tail = base_msg(job, task, role, party_id, detail)
    return "{}{} ready{}".format(head, msg, tail)
|
||||
|
||||
|
||||
def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    """Format a "start to <msg>" status line with job/task context."""
    head, tail = base_msg(job, task, role, party_id, detail)
    return "{}start to {}{}".format(head, msg, tail)
|
||||
|
||||
|
||||
def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    """Format a "<msg> successfully" status line with job/task context."""
    head, tail = base_msg(job, task, role, party_id, detail)
    return "{}{} successfully{}".format(head, msg, tail)
|
||||
|
||||
|
||||
def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    """Format a "<msg> is not effective" warning line with job/task context."""
    head, tail = base_msg(job, task, role, party_id, detail)
    return "{}{} is not effective{}".format(head, msg, tail)
|
||||
|
||||
|
||||
def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
    """Format a "failed to <msg>" status line with job/task context."""
    head, tail = base_msg(job, task, role, party_id, detail)
    return "{}failed to {}{}".format(head, msg, tail)
|
||||
|
||||
|
||||
def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
    """Build the (prefix, suffix) pair describing the job/task context.

    Precedence: task > job > role/party_id > bare detail.  The suffix
    always carries the optional " detail: \\n<detail>" tail.
    """
    detail_msg = f" detail: \n{detail}" if detail else ""
    if task is not None:
        return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}"
    if job is not None:
        return "", f" on {job.f_role} {job.f_party_id}{detail_msg}"
    if role and party_id:
        return "", f" on {role} {party_id}{detail_msg}"
    return "", detail_msg
|
||||
|
||||
|
||||
def exception_to_trace_string(ex):
    """Render *ex* with its traceback as one string.

    NOTE(review): duplicate of the identically-named function defined
    earlier in this module; this later definition wins at import time.
    """
    trace = traceback.TracebackException.from_exception(ex)
    return "".join(trace.format())
|
||||
|
||||
|
||||
def get_logger_base_dir():
    """Return the root directory used for rag_flow job log files."""
    return file_utils.get_rag_flow_directory('logs')
|
||||
|
||||
|
||||
def get_job_logger(job_id, log_type):
    """Create, wire and cache a logger that writes job-scoped log files."""
    rag_flow_log_dir = file_utils.get_rag_flow_directory('logs', 'rag_flow')
    job_log_dir = file_utils.get_rag_flow_directory('logs', job_id)
    if not job_id:
        log_dirs = [rag_flow_log_dir]
    else:
        if log_type == 'audit':
            # Audit logs are duplicated into the shared rag_flow directory.
            log_dirs = [job_log_dir, rag_flow_log_dir]
        else:
            log_dirs = [job_log_dir]
    if LoggerFactory.log_share:
        # Create directories world-writable so other processes can share them.
        oldmask = os.umask(000)
        os.makedirs(job_log_dir, exist_ok=True)
        os.makedirs(rag_flow_log_dir, exist_ok=True)
        os.umask(oldmask)
    else:
        os.makedirs(job_log_dir, exist_ok=True)
        os.makedirs(rag_flow_log_dir, exist_ok=True)
    logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
    # NOTE(review): the loop variable below shadows job_log_dir assigned
    # above; harmless here only because job_log_dir is not read afterwards.
    for job_log_dir in log_dirs:
        # Two handlers per directory: one at the factory level, plus an
        # ERROR-only handler writing the *_error variant of the file.
        handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
                                            log_dir=job_log_dir, log_type=log_type, job_id=job_id)
        error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id)
        logger.addHandler(handler)
        logger.addHandler(error_handler)
    with LoggerFactory.lock:
        LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
    return logger
|
||||
|
||||
18
api/utils/t_crypt.py
Normal file
18
api/utils/t_crypt.py
Normal file
@ -0,0 +1,18 @@
|
||||
import base64, os, sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from web_server.utils import decrypt, file_utils
|
||||
|
||||
def crypt(line):
    """RSA-encrypt *line* with conf/public.pem and return it base64-encoded."""
    file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
    # Fix: close the key file (the original open(...).read() leaked the handle).
    with open(file_path) as pem_file:
        rsa_key = RSA.importKey(pem_file.read())
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI helper: encrypt argv[1] with the public key, print the ciphertext,
    # then decrypt it again to verify the round trip.
    pswd = crypt(sys.argv[1])
    print(pswd)
    print(decrypt(pswd))
|
||||
|
||||
30
api/versions.py
Normal file
30
api/versions.py
Normal file
@ -0,0 +1,30 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
import typing
|
||||
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
def get_versions() -> typing.Mapping[str, typing.Any]:
    """Load version information from <project base>/rag.env as a mapping."""
    env_file = os.path.join(get_project_base_directory(), "rag.env")
    return dotenv.dotenv_values(dotenv_path=env_file)
|
||||
|
||||
def get_rag_version() -> typing.Optional[str]:
    """Return the RAG version string from rag.env, or None if absent."""
    versions = get_versions()
    return versions.get("RAG")
|
||||
Reference in New Issue
Block a user