rename web_server to api (#29)

* add front end code

* change licence

* rename web_server to API

* change name to web_server
KevinHuSh
2024-01-17 09:43:27 +08:00
committed by GitHub
parent c372afe40a
commit 6be3dd56fa
41 changed files with 284 additions and 262 deletions

0
api/__init__.py Normal file

147
api/apps/__init__.py Normal file

@@ -0,0 +1,147 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask, request
from werkzeug.wrappers.request import Request
from flask_cors import CORS
from web_server.db import StatusEnum
from web_server.db.services import UserService
from web_server.utils import CustomJSONEncoder
from flask_session import Session
from flask_login import LoginManager
from web_server.settings import RetCode, SECRET_KEY, stat_logger
from web_server.hook import HookManager
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from web_server.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
__all__ = ['app']
logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
logger.addHandler(h)
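# Parse request bodies as JSON regardless of the Content-Type header; malformed bodies yield None instead of raising.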
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Flask(__name__)
CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
## convenience for dev and debug
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
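# Auto-discovery of page modules: every *_app.py file under the page directories becomes a Blueprint mounted under /<API_VERSION>/<page_name>.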
def search_pages_path(pages_dir):
return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
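# Import a page module, inject the shared Flask app and a fresh Blueprint named after the page, then register it.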
def register_page(page_path):
page_name = page_path.stem[:-len('_app')]  # strip the '_app' suffix (rstrip would strip characters, not the suffix)
module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))
spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec)
page.app = app
page.manager = Blueprint(page_name, module_name)
sys.modules[module_name] = page
spec.loader.exec_module(page)
page_name = getattr(page, 'page_name', page_name)
url_prefix = f'/{API_VERSION}/{page_name}'
app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix
pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / 'web_server' / 'apps',
]
client_urls_prefix = [
register_page(path)
for dir in pages_dir
for path in search_pages_path(dir)
]
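# Optional before-request hooks; they only run when CLIENT_AUTHENTICATION / SITE_AUTHENTICATION are enabled in settings.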
def client_authentication_before_request():
result = HookManager.client_authentication(ClientAuthenticationParameters(
request.full_path, request.headers,
request.form, request.data, request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
def site_authentication_before_request():
for url_prefix in client_urls_prefix:
if request.path.startswith(url_prefix):
return
result = HookManager.site_authentication(AuthenticationParameters(
request.headers.get('site_signature'),
request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
return client_authentication_before_request()
if SITE_AUTHENTICATION:
return site_authentication_before_request()
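# Resolve the current user from the serialized access token carried in the Authorization header.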
@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:
access_token = str(jwt.loads(authorization))
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
return user[0]
else:
return None
except Exception as e:
stat_logger.exception(e)
return None
else:
return None

150
api/apps/chunk_app.py Normal file

@@ -0,0 +1,150 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import hashlib
import pathlib
import re
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace
from web_server.db import LLMType
from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.services.llm_service import TenantLLMService
from web_server.db.services.user_service import UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid
from web_server.db.services.document_service import DocumentService
from web_server.settings import RetCode
from web_server.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type
retrival = search.Dealer(ELASTICSEARCH, None)
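# Paginated listing of a document's chunks, searched against the tenant's Elasticsearch index with optional keywords.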
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
def list():
req = request.json
doc_id = req["doc_id"]
page = req.get("page", 1)
size = req.get("size", 30)
question = req.get("keywords", "")
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")
res = retrival.search({
"doc_ids": [doc_id], "page": page, "size": size, "question": question
}, search.index_name(tenants[0].tenant_id))
return get_json_result(data=res)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, retmsg=f'Index not found!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
chunk_id = request.args["chunk_id"]
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id))
if not res.get("found"):return server_error_response("Chunk not found")
id = res["_id"]
res = res["_source"]
res["chunk_id"] = id
k = []
for n in res.keys():
if re.search(r"(_vec$|_sm_)", n):
k.append(n)
if re.search(r"(_tks|_ltks)", n):
res[n] = rmSpace(res[n])
for n in k: del res[n]
return get_json_result(data=res)
except Exception as e:
if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, retmsg=f'Chunk not found!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)
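# Update an existing chunk: re-tokenize the content, rebuild the weighted embedding (10% title, 90% body) and upsert it into Elasticsearch.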
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd")
def set():
req = request.json
d = {"id": req["chunk_id"]}
d["content_ltks"] = huqie.qie(req["content_ltks"])
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
v, c = embd_mdl.encode([req["docnm_kwd"], req["content_ltks"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec"%len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks", "important_kwd")
def create():
req = request.json
md5 = hashlib.md5()
md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: return get_data_error_result(retmsg="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["doc_id"] = doc.id
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec"%len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data={"chunk_id": chunck_id})
except Exception as e:
return server_error_response(e)

279
api/apps/document_app.py Normal file

@@ -0,0 +1,279 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import pathlib
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
from rag.nlp import search
from rag.utils import ELASTICSEARCH
from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid
from web_server.db import FileType
from web_server.db.services.document_service import DocumentService
from web_server.settings import RetCode
from web_server.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type
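# Upload flow: look up the knowledge base, de-duplicate the file name, store the blob in MinIO, then create the document record.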
@manager.route('/upload', methods=['POST'])
@login_required
@validate_request("kb_id")
def upload():
kb_id = request.form.get("kb_id")
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb.id)
location = filename
while MINIO.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, filename, blob)
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"created_by": current_user.id,
"type": filename_type(filename),
"name": filename,
"location": location,
"size": len(blob)
})
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("name", "kb_id")
def create():
req = request.json
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if DocumentService.query(name=req["name"], kb_id=kb_id):
return get_data_error_result(
retmsg="Duplicated document name in the same knowledgebase.")
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"created_by": current_user.id,
"type": FileType.VIRTUAL,
"name": req["name"],
"location": "",
"size": 0
})
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
keywords = request.args.get("keywords", "")
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
docs = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords)
return get_json_result(data=docs)
except Exception as e:
return server_error_response(e)
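# Toggling a document's status also adds or removes its kb_id on every chunk in Elasticsearch via an update-by-query script.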
@manager.route('/change_status', methods=['POST'])
@login_required
@validate_request("doc_id", "status")
def change_status():
req = request.json
if str(req["status"]) not in ["0", "1"]:
return get_json_result(
data=False,
retmsg='"Status" must be either 0 or 1!',
retcode=RetCode.ARGUMENT_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if not DocumentService.update_by_id(
req["doc_id"], {"status": str(req["status"])}):
return get_data_error_result(
retmsg="Database error (Document update)!")
if str(req["status"]) == "0":
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="""
if(ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.remove(
ctx._source.kb_id.indexOf('%s')
);
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name(
kb.tenant_id)
)
else:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name(
kb.tenant_id)
)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("doc_id")
def rm():
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
return get_json_result(data=False, retmsg='Remove from ES failure', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result(
retmsg="Database error (Document removal)!")
MINIO.rm(doc.kb_id, doc.location)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/rename', methods=['POST'])
@login_required
@validate_request("doc_id", "name", "old_name")
def rename():
req = request.json
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
req["old_name"].lower()).suffix:
return get_json_result(
data=False,
retmsg="The extension of file can't be changed",
retcode=RetCode.ARGUMENT_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if DocumentService.query(name=req["name"], kb_id=doc.kb_id):
return get_data_error_result(
retmsg="Duplicated document name in the same knowledgebase.")
if not DocumentService.update_by_id(
req["doc_id"], {"name": req["name"]}):
return get_data_error_result(
retmsg="Database error (Document rename)!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
doc_id = request.args["doc_id"]
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(retmsg="Document not found!")
blob = MINIO.get(doc.kb_id, doc.location)
return get_json_result(data={"base64": base64.b64decode(blob)})
except Exception as e:
return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
if not e:
return get_data_error_result(retmsg="Document not found!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

114
api/apps/kb_app.py Normal file

@@ -0,0 +1,114 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
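# Knowledge base management; update and rm verify that the caller created the knowledge base before mutating it.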
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "parser_id")
def create():
req = request.json
req["name"] = req["name"].strip()
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
try:
req["id"] = get_uuid()
req["tenant_id"] = current_user.id
req["created_by"] = current_user.id
if not KnowledgebaseService.save(**req): return get_data_error_result()
return get_json_result(data={"kb_id": req["id"]})
except Exception as e:
return server_error_response(e)
@manager.route('/update', methods=['post'])
@login_required
@validate_request("kb_id", "name", "description", "permission", "parser_id")
def update():
req = request.json
req["name"] = req["name"].strip()
try:
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
del req["kb_id"]
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
e, kb = KnowledgebaseService.get_by_id(kb.id)
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
return get_json_result(data=kb.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
kb_id = request.args["kb_id"]
try:
kb = KnowledgebaseService.get_detail(kb_id)
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
return get_json_result(data=kb)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
return get_json_result(data=kbs)
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['post'])
@login_required
@validate_request("kb_id")
def rm():
req = request.json
try:
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

93
api/apps/llm_app.py Normal file

@@ -0,0 +1,93 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
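# Per-tenant LLM configuration: list factories, store API keys, and report which models are available for the tenant.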
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac])
except Exception as e:
return server_error_response(e)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
llm = {
"tenant_id": current_user.id,
"llm_factory": req["llm_factory"],
"api_key": req["api_key"]
}
# TODO: Test api_key
for n in ["model_type", "llm_name"]:
if n in req: llm[n] = req[n]
TenantLLM.insert(**llm).on_conflict("replace").execute()
return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
try:
objs = TenantLLMService.get_my_llms(current_user.id)
return get_json_result(data=objs)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs if o.api_key]
fct = {}
for o in objs:
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = False
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
m["available"] = True
res = {}
for m in llms:
if m["fid"] not in res: res[m["fid"]] = []
res[m["fid"]].append(m)
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)

268
api/apps/user_app.py Normal file

@@ -0,0 +1,268 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger
from web_server.utils.api_utils import get_json_result, cors_reponse
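# Login supports two channels: email/password and GitHub OAuth (github_callback stores the access token in the session and redirects back here).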
@manager.route('/login', methods=['POST', 'GET'])
def login():
userinfo = None
login_channel = "password"
if session.get("access_token"):
login_channel = session["access_token_from"]
if session["access_token_from"] == "github":
userinfo = user_info_from_github(session["access_token"])
elif not request.json:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg='Unauthorized!')
email = request.json.get('email') if not userinfo else userinfo["email"]
users = UserService.query(email=email)
if not users:
if request.json is not None:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
avatar = ""
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
user_id = get_uuid()
try:
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": login_channel,
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
login_user(users[0])
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
password = request.json.get('password')
try:
password = decrypt(password)
except:
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Failed to decrypt password')
user = UserService.query_user(email, password)
if user:
response_data = user.to_json()
user.access_token = get_uuid()
login_user(user)
user.save()
msg = "Welcome back!"
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
else:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
@manager.route('/github_callback', methods=['GET'])
def github_callback():
try:
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
},headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
return redirect(url_for("user.login"), code=307)
except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
def user_info_from_github(access_token):
import requests
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json()
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
return user_info
@manager.route("/logout", methods=['GET'])
@login_required
def log_out():
current_user.access_token = ""
current_user.save()
logout_user()
return get_json_result(data=True)
@manager.route("/setting", methods=["POST"])
@login_required
def setting_user():
update_dict = {}
request_data = request.json
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
for k in request_data.keys():
if k in ["password", "new_password"]:continue
update_dict[k] = request_data[k]
try:
UserService.update_by_id(current_user.id, update_dict)
return get_json_result(data=True)
except Exception as e:
stat_logger.exception(e)
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/info", methods=["GET"])
@login_required
def user_info():
return get_json_result(data=current_user.to_dict())
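# Best-effort cleanup of the tenant, membership and LLM records created during a failed registration.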
def rollback_user_registration(user_id):
try:
TenantService.delete_by_id(user_id)
except Exception as e:
pass
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
pass
try:
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
except Exception as e:
pass
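# Create the user together with a personal tenant, an OWNER membership and a default OpenAI TenantLLM record.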
def user_register(user_id, user):
user_id = get_uuid()
user["id"] = user_id
tenant = {
"id": user_id,
"name": user["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_id,
"user_id": user_id,
"invited_by": user_id,
"role": UserTenantRole.OWNER
}
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
if not UserService.save(**user):return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
return UserService.query(email=user["email"])
@manager.route("/register", methods=["POST"])
@validate_request("nickname", "email", "password")
def user_add():
req = request.json
if UserService.query(email=req["email"]):
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
user_dict = {
"access_token": get_uuid(),
"email": req["email"],
"nickname": req["nickname"],
"password": decrypt(req["password"]),
"login_channel": "password",
"last_login_time": get_format_time(),
"is_superuser": False,
}
user_id = get_uuid()
try:
users = user_register(user_id, user_dict)
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/tenant_info", methods=["GET"])
@login_required
def tenant_info():
try:
tenants = TenantService.get_by_user_id(current_user.id)[0]
return get_json_result(data=tenants)
except Exception as e:
return server_error_response(e)
@manager.route("/set_tenant_info", methods=["POST"])
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
req = request.json
try:
tid = req["tenant_id"]
del req["tenant_id"]
TenantService.update_by_id(tid, req)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

54
api/db/__init__.py Normal file

@@ -0,0 +1,54 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
from enum import IntEnum
from strenum import StrEnum
class StatusEnum(Enum):
VALID = "1"
IN_VALID = "0"
class UserTenantRole(StrEnum):
OWNER = 'owner'
ADMIN = 'admin'
NORMAL = 'normal'
class TenantPermission(StrEnum):
ME = 'me'
TEAM = 'team'
class SerializedType(IntEnum):
PICKLE = 1
JSON = 2
class FileType(StrEnum):
PDF = 'pdf'
DOC = 'doc'
VISUAL = 'visual'
AURAL = 'aural'
VIRTUAL = 'virtual'
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'

619
api/db/db_models.py Normal file
View File

@ -0,0 +1,619 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import os
import sys
import typing
import operator
from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin
from peewee import (
BigAutoField, BigIntegerField, BooleanField, CharField,
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase
from web_server.db import SerializedType
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
from web_server.utils.log_utils import getLogger
from web_server import utils
LOGGER = getLogger()
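# Cache one instance per (class, process id) so each worker process keeps a single shared object, e.g. the DB connection pool.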
def singleton(cls, *args, **kw):
instances = {}
def _singleton():
key = str(cls) + str(os.getpid())
if key not in instances:
instances[key] = cls(*args, **kw)
return instances[key]
return _singleton
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
class LongTextField(TextField):
field_type = 'LONGTEXT'
class JSONField(LongTextField):
default_value = {}
def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if value is None:
value = self.default_value
return utils.json_dumps(value)
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
default_value = []
class SerializedField(LongTextField):
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
self._serialized_type = serialized_type
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return None
return utils.json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
if is_continuous_field(p):
return True
else:
return False
def auto_date_timestamp_field():
return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def auto_date_timestamp_db_field():
return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def remove_field_name_prefix(field_name):
return field_name[2:] if field_name.startswith('f_') else field_name
class BaseModel(Model):
create_time = BigIntegerField(null=True)
create_date = DateTimeField(null=True)
update_time = BigIntegerField(null=True)
update_date = DateTimeField(null=True)
def to_json(self):
# This function is obsolete
return self.to_dict()
def to_dict(self):
return self.__dict__['__data__']
def to_human_model_dict(self, only_primary_with: list = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
human_model_dict = {}
for k in self._meta.primary_key.field_names:
human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
for k in only_primary_with:
human_model_dict[k] = model_dict[f'f_{k}']
return human_model_dict
@property
def meta(self) -> Metadata:
return self._meta
@classmethod
def get_primary_keys_name(cls):
return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
cls._meta.primary_key.name]
@classmethod
def getter_by(cls, attr):
return operator.attrgetter(attr)(cls)
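# Generic filter builder: two-element lists on continuous fields become range filters, other lists become IN filters, everything else is an equality match.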
@classmethod
def query(cls, reverse=None, order_by=None, **kwargs):
filters = []
for f_n, f_v in kwargs.items():
attr_name = '%s' % f_n
if not hasattr(cls, attr_name) or f_v is None:
continue
if type(f_v) in {list, set}:
f_v = list(f_v)
if is_continuous_field(type(getattr(cls, attr_name))):
if len(f_v) == 2:
for i, v in enumerate(f_v):
if isinstance(v, str) and f_n in auto_date_timestamp_field():
# time type: %Y-%m-%d %H:%M:%S
f_v[i] = utils.date_string_to_timestamp(v)
lt_value = f_v[0]
gt_value = f_v[1]
if lt_value is not None and gt_value is not None:
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
elif lt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
elif gt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
else:
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
else:
filters.append(operator.attrgetter(attr_name)(cls) == f_v)
if filters:
query_records = cls.select().where(*filters)
if reverse is not None:
if not order_by or not hasattr(cls, f"{order_by}"):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
elif reverse is False:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
return [query_record for query_record in query_records]
else:
return []
@classmethod
def insert(cls, __data=None, **insert):
if isinstance(__data, dict) and __data:
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
if insert:
insert["create_time"] = utils.current_timestamp()
return super().insert(__data, **insert)
# update and insert will call this method
@classmethod
def _normalize_data(cls, data, kwargs):
normalized = super()._normalize_data(data, kwargs)
if not normalized:
return {}
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
cls._meta.combined[f"{f_n}_time"] in normalized and \
normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
normalized[cls._meta.combined[f"{f_n}_time"]])
return normalized
class JsonSerializedField(SerializedField):
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
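# Process-wide pooled MySQL connection built from the DATABASE settings.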
@singleton
class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
stat_logger.info('init mysql database on cluster mode successfully')
class DatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
self.timeout = int(timeout)
self.db = db if db else DB
def lock(self):
# SQL parameters only support %s format placeholders
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
elif ret[0] == 1:
return True
else:
raise Exception(f'failed to acquire lock {self.lock_name}')
def unlock(self):
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
elif ret[0] == 1:
return True
else:
raise Exception(f'mysql lock {self.lock_name} does not exist')
def __enter__(self):
if isinstance(self.db, PooledMySQLDatabase):
self.lock()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(self.db, PooledMySQLDatabase):
self.unlock()
def __call__(self, func):
@wraps(func)
def magic(*args, **kwargs):
with self:
return func(*args, **kwargs)
return magic
DB = BaseDataBase().database_connection
DB.lock = DatabaseLock
def close_connection():
try:
if DB:
DB.close()
except Exception as e:
LOGGER.exception(e)
class DataBaseModel(BaseModel):
class Meta:
database = DB
@DB.connection_context()
def init_database_tables():
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
create_failed_list = []
for name, obj in members:
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
table_objs.append(obj)
LOGGER.info(f"start create table {obj.__name__}")
try:
obj.create_table()
LOGGER.info(f"create table success: {obj.__name__}")
except Exception as e:
LOGGER.exception(e)
create_failed_list.append(obj.__name__)
if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = '%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
class User(DataBaseModel, UserMixin):
id = CharField(max_length=32, primary_key=True)
access_token = CharField(max_length=255, null=True)
nickname = CharField(max_length=100, null=False, help_text="nickname")
password = CharField(max_length=255, null=True, help_text="password")
email = CharField(max_length=255, null=False, help_text="email", index=True)
avatar = TextField(null=True, help_text="avatar base64 string")
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
last_login_time = DateTimeField(null=True)
is_authenticated = CharField(max_length=1, null=False, default="1")
is_active = CharField(max_length=1, null=False, default="1")
is_anonymous = CharField(max_length=1, null=False, default="0")
login_channel = CharField(null=True, help_text="from which user login")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
is_superuser = BooleanField(null=True, help_text="is root", default=False)
def __str__(self):
return self.email
def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
return jwt.dumps(str(self.access_token))
class Meta:
db_table = "user"
class Tenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
name = CharField(max_length=100, null=True, help_text="Tenant name")
public_key = CharField(max_length=255, null=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default parser IDs")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
class Meta:
db_table = "tenant"
class UserTenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
user_id = CharField(max_length=32, null=False)
tenant_id = CharField(max_length=32, null=False)
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
invited_by = CharField(max_length=32, null=False)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
class Meta:
db_table = "user_tenant"
class InvitationCode(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
code = CharField(max_length=32, null=False)
visit_time = DateTimeField(null=True)
user_id = CharField(max_length=32, null=True)
tenant_id = CharField(max_length=32, null=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
class Meta:
db_table = "invitation_code"
class LLMFactories(DataBaseModel):
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
logo = TextField(null=True, help_text="llm logo base64")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "llm_factories"
class LLM(DataBaseModel):
# default LLMs for every user
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
def __str__(self):
return self.llm_name
class Meta:
db_table = "llm"
class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
def __str__(self):
return self.llm_name
class Meta:
db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
description = TextField(null=True, help_text="KB description")
permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "knowledgebase"
class Document(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")
name = CharField(max_length=255, null=True, help_text="file name", index=True)
location = CharField(max_length=255, null=True, help_text="where is it stored")
size = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
class Meta:
db_table = "document"
class Dialog(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=255, null=True, help_text="dialog application name")
description = TextField(null=True, help_text="Dialog description")
icon = CharField(max_length=16, null=False, help_text="dialog icon")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
default="Creative")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")
class Meta:
db_table = "dialog"
class DialogKb(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True)
kb_id = CharField(max_length=32, null=False)
class Meta:
db_table = "dialog_kb"
primary_key = CompositeKey('dialog_id', 'kb_id')
class Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name")
message = JSONField(null=True)
class Meta:
db_table = "conversation"
"""
class Job(DataBaseModel):
# multi-party common configuration
f_user_id = CharField(max_length=25, null=True)
f_job_id = CharField(max_length=25, index=True)
f_name = CharField(max_length=500, null=True, default='')
f_description = TextField(null=True, default='')
f_tag = CharField(max_length=50, null=True, default='')
f_dsl = JSONField()
f_runtime_conf = JSONField()
f_runtime_conf_on_party = JSONField()
f_train_runtime_conf = JSONField(null=True)
f_roles = JSONField()
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50)
f_status = CharField(max_length=50)
f_status_code = IntegerField(null=True)
f_user = JSONField()
# this party configuration
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_is_initiator = BooleanField(null=True, default=False)
f_progress = IntegerField(null=True, default=0)
f_ready_signal = BooleanField(default=False)
f_ready_time = BigIntegerField(null=True)
f_cancel_signal = BooleanField(default=False)
f_cancel_time = BigIntegerField(null=True)
f_rerun_signal = BooleanField(default=False)
f_end_scheduling_updates = IntegerField(null=True, default=0)
f_engine_name = CharField(max_length=50, null=True)
f_engine_type = CharField(max_length=10, null=True)
f_cores = IntegerField(default=0)
f_memory = IntegerField(default=0) # MB
f_remaining_cores = IntegerField(default=0)
f_remaining_memory = IntegerField(default=0) # MB
f_resource_in_use = BooleanField(default=False)
f_apply_resource_time = BigIntegerField(null=True)
f_return_resource_time = BigIntegerField(null=True)
f_inheritance_info = JSONField(null=True)
f_inheritance_status = CharField(max_length=50, null=True)
f_start_time = BigIntegerField(null=True)
f_start_date = DateTimeField(null=True)
f_end_time = BigIntegerField(null=True)
f_end_date = DateTimeField(null=True)
f_elapsed = BigIntegerField(null=True)
class Meta:
db_table = "t_job"
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
class PipelineComponentMeta(DataBaseModel):
f_model_id = CharField(max_length=100, index=True)
f_model_version = CharField(max_length=100, index=True)
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_component_name = CharField(max_length=100, index=True)
f_component_module_name = CharField(max_length=100)
f_model_alias = CharField(max_length=100, index=True)
f_model_proto_index = JSONField(null=True)
f_run_parameters = JSONField(null=True)
f_archive_sha256 = CharField(max_length=100, null=True)
f_archive_from_ip = CharField(max_length=100, null=True)
class Meta:
db_table = 't_pipeline_component_meta'
indexes = (
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
)
"""

157
api/db/db_services.py Normal file

@@ -0,0 +1,157 @@
#
# Copyright 2021 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID
from web_server.versions import get_rag_version
from web_server.errors.error_services import *
from web_server.settings import (
GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger,
)
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
f'{HOST}:{GRPC_PORT}',
json.dumps({
'instance_id': instance_id,
'timestamp': round(time.time() * 1000),
'version': get_rag_version() or '',
'host': HOST,
'grpc_port': GRPC_PORT,
'http_port': HTTP_PORT,
}),
)
def check_service_supported(method):
"""Decorator to check if `service_name` is supported.
The attribute `supported_services` MUST be defined in the class.
The first and second arguments of `method` MUST be `self` and `service_name`.
:param Callable method: The class method.
:return: The inner wrapper function.
:rtype: Callable
"""
@wraps(method)
def magic(self, service_name, *args, **kwargs):
if service_name not in self.supported_services:
raise ServiceNotSupported(service_name=service_name)
return method(self, service_name, *args, **kwargs)
return magic
class ServicesDB(abc.ABC):
"""Database for storage service urls.
Abstract base class for the real backends.
"""
@property
@abc.abstractmethod
def supported_services(self):
"""The names of supported services.
The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
:return: The service names.
:rtype: list
"""
pass
@abc.abstractmethod
def _get_serving(self):
pass
def get_serving(self):
try:
return self._get_serving()
except ServicesError as e:
stat_logger.exception(e)
return []
@abc.abstractmethod
def _insert(self, service_name, service_url, value=''):
pass
@check_service_supported
def insert(self, service_name, service_url, value=''):
"""Insert a service url to database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._insert(service_name, service_url, value)
except ServicesError as e:
stat_logger.exception(e)
@abc.abstractmethod
def _delete(self, service_name, service_url):
pass
@check_service_supported
def delete(self, service_name, service_url):
"""Delete a service url from database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._delete(service_name, service_url)
except ServicesError as e:
stat_logger.exception(e)
def register_flow(self):
"""Call `self.insert` for insert the flow server address to databae.
:return: None
"""
self.insert('flow-server', *server_instance)
def unregister_flow(self):
"""Call `self.delete` for delete the flow server address from databae.
:return: None
"""
self.delete('flow-server', server_instance[0])
@abc.abstractmethod
def _get_urls(self, service_name, with_values=False):
pass
@check_service_supported
def get_urls(self, service_name, with_values=False):
"""Query service urls from database. The urls may belong to other nodes.
Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
`ragflow` is a url containing scheme, host, port and path,
while `servings` only contains host and port.
:param str service_name: The service name.
:return: The service urls.
:rtype: list
"""
try:
return self._get_urls(service_name, with_values)
except ServicesError as e:
stat_logger.exception(e)
return []
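A minimal sketch of a concrete backend for the abstract class above, not part of this diff; it assumes the module is importable as web_server.db.db_services (matching the imports used throughout this commit), and the class name and in-memory storage are illustrative only:
from web_server.db.db_services import ServicesDB

class InMemoryServicesDB(ServicesDB):
    # a plain attribute is enough to satisfy the abstract `supported_services` property
    supported_services = ['ragflow', 'servings', 'flow-server']

    def __init__(self):
        self._urls = {}  # service_name -> {service_url: value}

    def _insert(self, service_name, service_url, value=''):
        self._urls.setdefault(service_name, {})[service_url] = value

    def _delete(self, service_name, service_url):
        self._urls.get(service_name, {}).pop(service_url, None)

    def _get_urls(self, service_name, with_values=False):
        urls = self._urls.get(service_name, {})
        return list(urls.items()) if with_values else list(urls)

    def _get_serving(self):
        return list(self._urls.get('servings', {}))

# db = InMemoryServicesDB()
# db.register_flow()          # records ('flow-server', f'{HOST}:{GRPC_PORT}')
# db.get_urls('servings')     # [] until a serving instance registers itself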

131
api/db/db_utils.py Normal file
View File

@ -0,0 +1,131 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
from functools import reduce
from typing import Dict, Type, Union
from web_server.utils import current_timestamp, timestamp_to_date
from web_server.db.db_models import DB, DataBaseModel
from web_server.db.runtime_config import RuntimeConfig
from web_server.utils.log_utils import getLogger
from enum import Enum
LOGGER = getLogger()
@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
DB.create_tables([model])
current_time = current_timestamp()
current_date = timestamp_to_date(current_time)
for data in data_source:
if 'f_create_time' not in data:
data['f_create_time'] = current_time
data['f_create_date'] = timestamp_to_date(data['f_create_time'])
data['f_update_time'] = current_time
data['f_update_date'] = current_date
preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})
batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000
for i in range(0, len(data_source), batch_size):
with DB.atomic():
query = model.insert_many(data_source[i:i + batch_size])
if replace_on_conflict:
query = query.on_conflict(preserve=preserve)
query.execute()
def get_dynamic_db_model(base, job_id):
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
def get_dynamic_tracking_table_index(job_id):
return job_id[:8]
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = 'f_%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
supported_operators = {
'==': operator.eq,
'<': operator.lt,
'<=': operator.le,
'>': operator.gt,
'>=': operator.ge,
'!=': operator.ne,
'<<': operator.lshift,
'>>': operator.rshift,
'%': operator.mod,
'**': operator.pow,
'^': operator.xor,
'~': operator.inv,
}
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
expression = []
for field, value in query.items():
if not isinstance(value, (list, tuple)):
value = ('==', value)
op, *val = value
field = getattr(model, f'f_{field}')
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
expression.append(value)
return reduce(operator.iand, expression)
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
query: dict = None, order_by: Union[str, list, tuple] = None):
data = model.select()
if query:
data = data.where(query_dict2expression(model, query))
count = data.count()
if not order_by:
order_by = 'create_time'
if not isinstance(order_by, (list, tuple)):
order_by = (order_by, 'asc')
order_by, order = order_by
order_by = getattr(model, f'f_{order_by}')
order_by = getattr(order_by, order)()
data = data.order_by(order_by)
if limit > 0:
data = data.limit(limit)
if offset > 0:
data = data.offset(offset)
return list(data), count
class StatusEnum(Enum):
# record availability status
VALID = "1"
IN_VALID = "0"
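A hedged usage sketch for query_db/query_dict2expression above; Task is a hypothetical DataBaseModel subclass with the FATE-style `f_` field prefix these helpers expect, and is not defined anywhere in this diff:
from web_server.db.db_utils import query_db

# records, total = query_db(
#     Task,                                  # hypothetical f_-prefixed model
#     limit=10, offset=0,
#     query={
#         'status': 'running',               # a bare value implies '=='
#         'progress': ('>=', 50),            # (operator, value) tuples
#         'party_id': ('in_', ['9999']),     # unknown ops fall back to field methods
#     },
#     order_by=('create_time', 'desc'),      # both get the 'f_' prefix internally
# )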

141
api/db/init_data.py Normal file
View File

@ -0,0 +1,141 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid
from web_server.db import LLMType
from web_server.db.db_models import init_database_tables as init_web_db
from web_server.db.services import UserService
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
def init_superuser():
user_info = {
"id": uuid.uuid1().hex,
"password": "admin",
"nickname": "admin",
"is_superuser": True,
"email": "kai.hu@infiniflow.org",
"creator": "system",
"status": "1",
}
UserService.save(**user_info)
def init_llm_factory():
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "通义千问",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智普AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "文心一言",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
]
llm_infos = [{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo-16k-0613",
"tags": "LLM,CHAT,16k",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-32k",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-plus",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2",
"tags": "TEXT EMBEDDING,2K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen_vl_chat_v1",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
for info in llm_infos:
LLMService.save(**info)
def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():
    init_llm_factory()
print("init web data success:{}".format(time.time() - start_time))
if __name__ == '__main__':
init_web_db()
init_web_data()
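For reference, a hedged sketch of bootstrapping manually, assuming the package root is on PYTHONPATH and the database configured in service_conf.yaml is reachable:
from web_server.db.db_models import init_database_tables
from web_server.db.init_data import init_web_data

init_database_tables()   # create any missing tables
init_web_data()          # seed the superuser and LLM factory/model rows if absent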

21
api/db/operatioins.py Normal file
View File

@ -0,0 +1,21 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import time
import typing
from web_server.utils.log_utils import sql_logger
import peewee

View File

@ -0,0 +1,27 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ReloadConfigBase:
@classmethod
def get_all(cls):
configs = {}
for k, v in cls.__dict__.items():
if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
configs[k] = v
return configs
@classmethod
def get(cls, config_name):
return getattr(cls, config_name) if hasattr(cls, config_name) else None

54
api/db/runtime_config.py Normal file
View File

@ -0,0 +1,54 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from web_server.versions import get_versions
from .reload_config_base import ReloadConfigBase
class RuntimeConfig(ReloadConfigBase):
DEBUG = None
WORK_MODE = None
HTTP_PORT = None
JOB_SERVER_HOST = None
JOB_SERVER_VIP = None
ENV = dict()
SERVICE_DB = None
LOAD_CONFIG_MANAGER = False
@classmethod
def init_config(cls, **kwargs):
for k, v in kwargs.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def init_env(cls):
cls.ENV.update(get_versions())
@classmethod
def load_config_manager(cls):
cls.LOAD_CONFIG_MANAGER = True
@classmethod
def get_env(cls, key):
return cls.ENV.get(key, None)
@classmethod
def get_all_env(cls):
return cls.ENV
@classmethod
def set_service_db(cls, service_db):
cls.SERVICE_DB = service_db

View File

@ -0,0 +1,38 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
from .user_service import UserService
def duplicate_name(query_func, **kwargs):
fnm = kwargs["name"]
objs = query_func(**kwargs)
if not objs: return fnm
ext = pathlib.Path(fnm).suffix #.jpg
nm = re.sub(r"%s$"%ext, "", fnm)
r = re.search(r"\([0-9]+\)$", nm)
c = 0
if r:
c = int(r.group(1))
nm = re.sub(r"\([0-9]+\)$", "", nm)
c += 1
nm = f"{nm}({c})"
if ext: nm += f"{ext}"
kwargs["name"] = nm
return duplicate_name(query_func, **kwargs)
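A hedged usage sketch for duplicate_name above; the service and lookup keywords are illustrative:
from web_server.db.services import duplicate_name
from web_server.db.services.document_service import DocumentService

# "report.pdf" -> "report(1).pdf" if taken, then "report(2).pdf", and so on
unique = duplicate_name(DocumentService.query, name="report.pdf", kb_id="kb0")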

View File

@ -0,0 +1,153 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from web_server.db.db_models import DB
from web_server.utils import datetime_format
class CommonService:
model = None
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
def get_all(cls, cols=None, reverse=None, order_by=None):
if cols:
query_records = cls.model.select(*cols)
else:
query_records = cls.model.select()
if reverse is not None:
if not order_by or not hasattr(cls, order_by):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
return query_records
@classmethod
@DB.connection_context()
def get(cls, **kwargs):
return cls.model.get(**kwargs)
@classmethod
@DB.connection_context()
def get_or_none(cls, **kwargs):
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
#if "id" not in kwargs:
# kwargs["id"] = get_uuid()
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert_many(cls, data_list, batch_size=100):
with DB.atomic():
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i:i + batch_size]).execute()
@classmethod
@DB.connection_context()
def update_many_by_id(cls, data_list):
cur = datetime_format(datetime.now())
with DB.atomic():
for data in data_list:
data["update_time"] = cur
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
def update_by_id(cls, pid, data):
data["update_time"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod
@DB.connection_context()
def get_by_id(cls, pid):
try:
obj = cls.model.query(id=pid)[0]
return True, obj
except Exception as e:
return False, None
@classmethod
@DB.connection_context()
def get_by_ids(cls, pids, cols=None):
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
return objs.where(cls.model.id.in_(pids))
@classmethod
@DB.connection_context()
def delete_by_id(cls, pid):
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@classmethod
@DB.connection_context()
def filter_update(cls, filters, update_data):
with DB.atomic():
cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
return result
@classmethod
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list
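Every service in this diff follows the pattern above: bind a peewee model to the class and inherit the shared CRUD helpers. A minimal sketch, assuming the Document model from db_models:
from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService

class MyDocumentService(CommonService):
    model = Document

# e, doc = MyDocumentService.get_by_id("doc-id")          # (found?, record)
# MyDocumentService.update_by_id("doc-id", {"name": "renamed.pdf"})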

View File

@ -0,0 +1,35 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Dialog, Conversation, DialogKb
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class DialogService(CommonService):
model = Dialog
class ConversationService(CommonService):
model = Conversation
class DialogKbService(CommonService):
model = DialogKb

View File

@ -0,0 +1,96 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_utils import StatusEnum
class DocumentService(CommonService):
model = Document
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords):
if keywords:
docs = cls.model.select().where(
cls.model.kb_id == kb_id,
cls.model.name.like(f"%%{keywords}%%"))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def insert(cls, doc):
if not cls.save(**doc):
raise RuntimeError("Database error (Document)!")
e, doc = cls.get_by_id(doc["id"])
if not e:
raise RuntimeError("Database error (Document retrieval)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not KnowledgebaseService.update_by_id(
kb.id, {"doc_num": kb.doc_num + 1}):
raise RuntimeError("Database error (Knowledgebase)!")
return doc
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation+duation).where(
cls.model.id == doc_id).execute()
if num == 0:
    raise LookupError("Document not found, although it is expected to exist.")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:return
return docs[0]["tenant_id"]

View File

@ -0,0 +1,70 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
model = Knowledgebase
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
fields = [
cls.model.id,
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
if not kbs:
return
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d

View File

@ -0,0 +1,31 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Knowledgebase, Document
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
model = Knowledgebase
class DocumentService(CommonService):
model = Document

View File

@ -0,0 +1,76 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from rag.llm import EmbeddingModel, CvModel
from web_server.db import LLMType
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
from web_server.db.services.common_service import CommonService
from web_server.db.db_utils import StatusEnum
class LLMFactoriesService(CommonService):
model = LLMFactories
class LLMService(CommonService):
model = LLM
class TenantLLMService(CommonService):
model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type),
(LLM.status == StatusEnum.VALID.value)
)
if not objs:return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory==LLMFactories.name)).where(cls.model.tenant_id==tenant_id).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=llm_type)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else:
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
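A hedged usage sketch for TenantLLMService.model_instance above; the tenant id is made up, and the returned object's interface comes from rag.llm, which is outside this diff:
from web_server.db import LLMType
from web_server.db.services.llm_service import TenantLLMService

embd_mdl = TenantLLMService.model_instance("tenant-xyz", LLMType.EMBEDDING)
# falls back to the "local" factory when the tenant has no API key configured,
# and returns None when no EmbeddingModel is registered for that factory
if embd_mdl is None:
    raise RuntimeError("no embedding model available for this tenant")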

View File

@ -0,0 +1,105 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import UserTenantRole
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import User, Tenant
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class UserService(CommonService):
model = User
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_id):
try:
user = cls.model.select().where(cls.model.id == user_id).get()
return user
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def query_user(cls, email, password):
user = cls.model.select().where((cls.model.email == email),
(cls.model.status == StatusEnum.VALID.value)).first()
if user and check_password_hash(str(user.password), password):
return user
else:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
obj = cls.model(**kwargs).save(force_insert=True)
return obj
@classmethod
@DB.connection_context()
def delete_user(cls, user_ids, update_user_dict):
with DB.atomic():
cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
@classmethod
@DB.connection_context()
def update_user(cls, user_id, user_dict):
date_time = get_format_time()
with DB.atomic():
if user_dict:
user_dict["update_time"] = date_time
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
class TenantService(CommonService):
model = Tenant
@classmethod
@DB.connection_context()
def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_joined_tenants_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
class UserTenantService(CommonService):
model = UserTenant
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
obj = cls.model(**kwargs).save(force_insert=True)
return obj
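A hedged sketch of the register/login round trip the two classmethods above support; the credentials are made up:
from web_server.db.services.user_service import UserService

UserService.save(email="demo@example.com", nickname="demo",
                 password="demo123", status="1")              # hashed before storage
user = UserService.query_user("demo@example.com", "demo123")  # None on bad password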

10
api/errors/__init__.py Normal file
View File

@ -0,0 +1,10 @@
from .general_error import *
class RagFlowError(Exception):
message = 'Unknown Rag Flow Error'
def __init__(self, message=None, *args, **kwargs):
message = str(message) if message is not None else self.message
message = message.format(*args, **kwargs)
super().__init__(message)

View File

@ -0,0 +1,13 @@
from web_server.errors import RagFlowError
__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']
class ServicesError(RagFlowError):
message = 'Unknown services error'
class ServiceNotSupported(ServicesError):
message = 'The service {service_name} is not supported'

View File

@ -0,0 +1,21 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ParameterError(Exception):
pass
class PassError(Exception):
pass

Binary file not shown.

57
api/hook/__init__.py Normal file
View File

@ -0,0 +1,57 @@
import importlib
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
from web_server.settings import HOOK_MODULE, stat_logger,RetCode
class HookManager:
SITE_SIGNATURE = []
SITE_AUTHENTICATION = []
CLIENT_AUTHENTICATION = []
PERMISSION_CHECK = []
@staticmethod
def init():
if HOOK_MODULE is not None:
for modules in HOOK_MODULE.values():
for module in modules.split(";"):
try:
importlib.import_module(module)
except Exception as e:
stat_logger.exception(e)
@staticmethod
def register_site_signature_hook(func):
HookManager.SITE_SIGNATURE.append(func)
@staticmethod
def register_site_authentication_hook(func):
HookManager.SITE_AUTHENTICATION.append(func)
@staticmethod
def register_client_authentication_hook(func):
HookManager.CLIENT_AUTHENTICATION.append(func)
@staticmethod
def register_permission_check_hook(func):
HookManager.PERMISSION_CHECK.append(func)
@staticmethod
def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
if HookManager.CLIENT_AUTHENTICATION:
return HookManager.CLIENT_AUTHENTICATION[0](parm)
return ClientAuthenticationReturn()
@staticmethod
def site_signature(parm: SignatureParameters) -> SignatureReturn:
if HookManager.SITE_SIGNATURE:
return HookManager.SITE_SIGNATURE[0](parm)
return SignatureReturn()
@staticmethod
def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if HookManager.SITE_AUTHENTICATION:
return HookManager.SITE_AUTHENTICATION[0](parm)
return AuthenticationReturn()
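A hedged sketch of registering a custom hook against HookManager above; the header name checked here is an assumption, not part of this diff:
from web_server.hook import HookManager
from web_server.hook.common.parameters import (
    ClientAuthenticationParameters, ClientAuthenticationReturn,
)
from web_server.settings import RetCode

@HookManager.register_client_authentication_hook
def check_api_token(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
    # reject the request unless the expected token header is present (assumed header name)
    if parm.headers.get("X-Api-Token") == "expected-token":
        return ClientAuthenticationReturn()
    return ClientAuthenticationReturn(code=RetCode.AUTHENTICATION_ERROR,
                                      message="invalid api token")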

View File

@ -0,0 +1,29 @@
import requests
from web_server.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName
from web_server.hook import HookManager
from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from web_server.settings import HOOK_SERVER_NAME
@HookManager.register_client_authentication_hook
def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
service_list = ServiceRegistry.load_service(
server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
)
if not service_list:
raise Exception(f"client authentication error: no found server"
f" {HOOK_SERVER_NAME} service client_authentication")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"client authentication error: request authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return ClientAuthenticationReturn()

View File

@ -0,0 +1,25 @@
import requests
from web_server.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName
from web_server.hook import HookManager
from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn
from web_server.settings import HOOK_SERVER_NAME
@HookManager.register_permission_check_hook
def permission(parm: PermissionCheckParameters) -> PermissionReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
if not service_list:
raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"permission check error: request permission url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
return PermissionReturn()

View File

@ -0,0 +1,49 @@
import requests
from web_server.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName
from web_server.hook import HookManager
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
SignatureReturn
from web_server.settings import HOOK_SERVER_NAME, PARTY_ID
@HookManager.register_site_signature_hook
def signature(parm: SignatureParameters) -> SignatureReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
if not service_list:
raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code == 200:
if response.json().get("code") == 0:
return SignatureReturn(site_signature=response.json().get("data"))
else:
raise Exception(f"signature error: request signature url failed, result: {response.json()}")
else:
raise Exception(f"signature error: request signature url failed, status code {response.status_code}")
@HookManager.register_site_authentication_hook
def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if not parm.src_party_id or str(parm.src_party_id) == "0":
parm.src_party_id = PARTY_ID
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
if not service_list:
raise Exception(
f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"site authentication error: request site_authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return AuthenticationReturn()

View File

@ -0,0 +1,56 @@
from web_server.settings import RetCode
class ParametersBase:
def to_dict(self):
d = {}
for k, v in self.__dict__.items():
d[k] = v
return d
class ClientAuthenticationParameters(ParametersBase):
def __init__(self, full_path, headers, form, data, json):
self.full_path = full_path
self.headers = headers
self.form = form
self.data = data
self.json = json
class ClientAuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class SignatureParameters(ParametersBase):
def __init__(self, party_id, body):
self.party_id = party_id
self.body = body
class SignatureReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, site_signature=None):
self.code = code
self.site_signature = site_signature
class AuthenticationParameters(ParametersBase):
def __init__(self, site_signature, body):
self.site_signature = site_signature
self.body = body
class AuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class PermissionReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message

80
api/ragflow_server.py Normal file
View File

@ -0,0 +1,80 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# init env. must be the first import
import logging
import os
import signal
import sys
import traceback
from werkzeug.serving import run_simple
from web_server.apps import app
from web_server.db.runtime_config import RuntimeConfig
from web_server.hook import HookManager
from web_server.settings import (
HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
)
from web_server import utils
from web_server.db.db_models import init_database_tables as init_web_db
from web_server.db.init_data import init_web_data
from web_server.versions import get_versions
if __name__ == '__main__':
stat_logger.info(
f'project base: {utils.file_utils.get_project_base_directory()}'
)
# init db
init_web_db()
init_web_data()
# init runtime config
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
args = parser.parse_args()
if args.version:
print(get_versions())
sys.exit(0)
RuntimeConfig.DEBUG = args.debug
if RuntimeConfig.DEBUG:
stat_logger.info("run on debug mode")
RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
HookManager.init()
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
# start http server
try:
stat_logger.info("RAG Flow http server start...")
werkzeug_logger = logging.getLogger("werkzeug")
for h in access_logger.handlers:
werkzeug_logger.addHandler(h)
run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
except Exception:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGKILL)
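For reference, this entry point is typically launched as something like `python -m web_server.ragflow_server` (the exact module path depends on where the package is mounted); `--debug` turns on the werkzeug reloader and debugger, and `--version` prints get_versions() and exits. This assumes conf/service_conf.yaml is present and the package root is on PYTHONPATH.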

156
api/settings.py Normal file
View File

@ -0,0 +1,156 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from enum import IntEnum, Enum
from web_server.utils import get_base_config,decrypt_database_config
from web_server.utils.file_utils import get_project_base_directory
from web_server.utils.log_utils import LoggerFactory, getLogger
# Server
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"
ERROR_REPORT = True
ERROR_REPORT_WITH_PATH = False
MAX_TIMESTAMP_INTERVAL = 60
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
REQUEST_TRY_TIMES = 3
REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300
USE_REGISTRY = get_base_config("use_registry")
LLM = get_base_config("llm", {})
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
ASR_MDL = LLM.get("asr_model", "whisper-1")
PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report")
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
RAG_FLOW_UPDATE_CHECK = False
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow")
TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600)
NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST
NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT
RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False)
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
DATABASE = decrypt_database_config()
# Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10
stat_logger = getLogger("stat")
access_logger = getLogger("access")
database_logger = getLogger("database")
# Switch
# upload
UPLOAD_DATA_FROM_CLIENT = True
# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
# site
SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)
# permission
PERMISSION_CONF = get_base_config("permission", {})
PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
DATASET_PERMISSION = PERMISSION_CONF.get("dataset")
HOOK_MODULE = get_base_config("hook_module")
HOOK_SERVER_NAME = get_base_config("hook_server_name")
ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
# authentication
USE_AUTHENTICATION = False
USE_DATA_AUTHENTICATION = False
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
USE_DEFAULT_TIMEOUT = False
AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
class CustomEnum(Enum):
@classmethod
def valid(cls, value):
try:
cls(value)
return True
except:
return False
@classmethod
def values(cls):
return [member.value for member in cls.__members__.values()]
@classmethod
def names(cls):
return [member.name for member in cls.__members__.values()]
class PythonDependenceName(CustomEnum):
Rag_Source_Code = "python"
Python_Env = "miniconda"
class ModelStorage(CustomEnum):
REDIS = "redis"
MYSQL = "mysql"
class RetCode(IntEnum, CustomEnum):
SUCCESS = 0
NOT_EFFECTIVE = 10
EXCEPTION_ERROR = 100
ARGUMENT_ERROR = 101
DATA_ERROR = 102
OPERATING_ERROR = 103
CONNECTION_ERROR = 105
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
SERVER_ERROR = 500

321
api/utils/__init__.py Normal file
View File

@ -0,0 +1,321 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import datetime
import io
import json
import os
import pickle
import socket
import time
import uuid
import requests
from enum import Enum, IntEnum
import importlib
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from filelock import FileLock
from . import file_utils
SERVICE_CONF = "service_conf.yaml"
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
if default is None:
default = os.environ.get(key.upper())
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
if key is not None and key in local_config:
return local_config[key]
config_path = conf_realpath(conf_name)
config = file_utils.load_yaml_conf(config_path)
if not isinstance(config, dict):
raise ValueError(f'Invalid config file: "{config_path}".')
config.update(local_config)
return config.get(key, default) if key is not None else config
use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False)
class CoordinationCommunicationProtocol(object):
HTTP = "http"
GRPC = "grpc"
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__, "data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def rag_uuid():
return uuid.uuid1().hex
def string_to_bytes(string):
return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
def current_timestamp():
return int(time.time() * 1000)
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
if not timestamp:
timestamp = time.time()
timestamp = int(timestamp) / 1000
time_array = time.localtime(timestamp)
str_date = time.strftime(format_string, time_array)
return str_date
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
time_array = time.strptime(time_str, format_string)
time_stamp = int(time.mktime(time_array) * 1000)
return time_stamp
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def get_lan_ip():
if os.name != "nt":
import fcntl
import struct
def get_interface_ip(ifname):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
return socket.inet_ntoa(
fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
ip = socket.gethostbyname(socket.getfqdn())
if ip.startswith("127.") and os.name != "nt":
interfaces = [
"bond1",
"eth0",
"eth1",
"eth2",
"wlan0",
"wlan1",
"wifi0",
"ath0",
"ath1",
"ppp0",
]
for ifname in interfaces:
try:
ip = get_interface_ip(ifname)
break
except IOError as e:
pass
return ip or ''
def from_dict_hook(in_dict: dict):
if "type" in in_dict and "data" in in_dict:
if in_dict["module"] is None:
return in_dict["data"]
else:
return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"])
else:
return in_dict
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(database=None, passwd_key="passwd", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
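# Sketch of the "module#function" contract used by decrypt_database_password:
# the configured callable receives (private_key, ciphertext) and returns plaintext.
# The module path, config keys and decryptor below are hypothetical.
#
#   encrypt_password: true
#   encrypt_module: "my_pkg.crypto#toy_decrypt"
#   private_key: "dummy-key"
def toy_decrypt(private_key, ciphertext):
    # trivial stand-in for a real RSA decryptor, for illustration only
    return ciphertext[::-1]
# With such a config, decrypt_database_config(name="database") rewrites
# database["passwd"] in place before the connection is opened.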
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid():
return uuid.uuid1().hex
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second)
def get_format_time() -> datetime.datetime:
return datetime_format(datetime.datetime.now())
def str2date(date_time: str):
return datetime.datetime.strptime(date_time, '%Y-%m-%d')
def elapsed2time(elapsed):
seconds = elapsed / 1000
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return '%02d:%02d:%02d' % (hours, minutes, seconds)
def decrypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
def download_img(url):
if not url: return ""
response = requests.get(url)
return "data:" + \
response.headers.get('Content-Type', 'image/jpg') + ";" + \
"base64," + base64.b64encode(response.content).decode("utf-8")

212
api/utils/api_utils.py Normal file
View File

@ -0,0 +1,212 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import random
import time
from functools import wraps
from io import BytesIO
from flask import (
Response, jsonify, send_file,make_response,
request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES
from web_server.utils import json_dumps
from web_server.versions import get_rag_version
from web_server.settings import RetCode
from web_server.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
from web_server.utils import CustomJSONEncoder
from uuid import uuid1
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
def request(**kwargs):
sess = requests.Session()
stream = kwargs.pop('stream', sess.stream)
timeout = kwargs.pop('timeout', None)
kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()}
prepped = requests.Request(**kwargs).prepare()
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
        timestamp = str(round(time.time() * 1000))
nonce = str(uuid1())
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
timestamp.encode('ascii'),
nonce.encode('ascii'),
HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'',
urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii')
if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
]), 'sha1').digest()).decode('ascii')
prepped.headers.update({
'TIMESTAMP': timestamp,
'NONCE': nonce,
'APP-KEY': HTTP_APP_KEY,
'SIGNATURE': signature,
})
return sess.send(prepped, stream=stream, timeout=timeout)
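# Usage sketch of the signing wrapper above; the endpoint and payload are hypothetical.
# The TIMESTAMP/NONCE/APP-KEY/SIGNATURE headers are only attached when
# CLIENT_AUTHENTICATION, HTTP_APP_KEY and SECRET_KEY are configured.
def _example_signed_request():
    resp = request(
        method="POST",
        url="http://127.0.0.1:9380/v1/user/info",       # hypothetical endpoint
        json={"user_id": "demo"},
        headers={"content_type": "application/json"},   # underscores become dashes
    )
    return resp.status_code, resp.json()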
rag_version = get_rag_version() or ''
def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time."""
    # Capped exponential wait: REQUEST_WAIT_SEC * 2**retries, bounded above by REQUEST_MAX_WAIT_SEC.
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
countdown = random.randrange(countdown + 1)
# Adjust according to maximum wait time and account for negative values.
return max(0, countdown)
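# Sketch of a retry loop built on the helper above; the health-check URL is hypothetical.
def _example_retry_with_backoff():
    for retries in range(5):
        try:
            return requests.get("http://127.0.0.1:9380/v1/health")
        except requests.ConnectionError:
            time.sleep(get_exponential_backoff_interval(retries, full_jitter=True))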
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
import re
result_dict = {
"retcode": retcode,
"retmsg":retmsg,
# "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
"data": data,
"jobId": job_id,
"meta": meta,
}
response = {}
for key, value in result_dict.items():
if value is None and key != "retcode":
continue
else:
response[key] = value
return jsonify(response)
def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'):
import re
result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)}
response = {}
for key, value in result_dict.items():
if value is None and key != "retcode":
continue
else:
response[key] = value
return jsonify(response)
def server_error_response(e):
stat_logger.exception(e)
try:
if e.code==401:
return get_json_result(retcode=401, retmsg=repr(e))
except:
pass
if len(e.args) > 1:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
def error_response(response_code, retmsg=None):
if retmsg is None:
retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
return Response(json.dumps({
'retmsg': retmsg,
'retcode': response_code,
}), status=response_code, mimetype='application/json')
def validate_request(*args, **kwargs):
def wrapper(func):
@wraps(func)
def decorated_function(*_args, **_kwargs):
input_arguments = flask_request.json or flask_request.form.to_dict()
no_arguments = []
error_arguments = []
for arg in args:
if arg not in input_arguments:
no_arguments.append(arg)
for k, v in kwargs.items():
config_value = input_arguments.get(k, None)
if config_value is None:
no_arguments.append(k)
elif isinstance(v, (tuple, list)):
if config_value not in v:
error_arguments.append((k, set(v)))
elif config_value != v:
error_arguments.append((k, v))
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs)
return decorated_function
return wrapper
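# Usage sketch of the validator on a route; `manager` is assumed to be a Flask
# Blueprint, and the route, field names and accepted values are hypothetical.
# Positional names must be present in the request; keyword names constrain values.
def _example_page_route(manager):
    @manager.route('/create', methods=['POST'])
    @validate_request("name", "kb_id", lang=("English", "Chinese"))
    def create_document():
        req = flask_request.json or {}
        return get_json_result(data={"name": req["name"], "lang": req.get("lang")})
    return create_document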
def is_localhost(ip):
return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
def send_file_in_mem(data, filename):
if not isinstance(data, (str, bytes)):
data = json_dumps(data)
if isinstance(data, str):
data = data.encode('utf-8')
f = BytesIO()
f.write(data)
f.seek(0)
return send_file(f, as_attachment=True, attachment_filename=filename)
# NOTE: this simplified definition shadows the richer get_json_result defined earlier in this module.
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
response = {"retcode": retcode, "retmsg": retmsg, "data": data}
return jsonify(response)
def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {}
for key, value in result_dict.items():
if value is None and key != "retcode":
continue
else:
response_dict[key] = value
response = make_response(jsonify(response_dict))
if auth:
response.headers["Authorization"] = auth
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Method"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response

153
api/utils/file_utils.py Normal file
View File

@ -0,0 +1,153 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from cachetools import LRUCache, cached
from ruamel.yaml import YAML
from web_server.db import FileType
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
PROJECT_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.pardir,
)
)
if args:
return os.path.join(PROJECT_BASE, *args)
return PROJECT_BASE
def get_rag_directory(*args):
global RAG_BASE
if RAG_BASE is None:
RAG_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.pardir,
os.pardir,
)
)
if args:
return os.path.join(RAG_BASE, *args)
return RAG_BASE
def get_rag_python_directory(*args):
return get_rag_directory("python", *args)
@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path) as f:
return json.load(f)
except BaseException:
raise EnvironmentError(
"loading json file config from '{}' failed!".format(json_conf_path)
)
def dump_json_conf(config_data, conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path, "w") as f:
json.dump(config_data, f, indent=4)
except BaseException:
raise EnvironmentError(
"loading json file config from '{}' failed!".format(json_conf_path)
)
def load_json_conf_real_time(conf_path):
if os.path.isabs(conf_path):
json_conf_path = conf_path
else:
json_conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(json_conf_path) as f:
return json.load(f)
except BaseException:
raise EnvironmentError(
"loading json file config from '{}' failed!".format(json_conf_path)
)
def load_yaml_conf(conf_path):
if not os.path.isabs(conf_path):
conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(conf_path) as f:
yaml = YAML(typ='safe', pure=True)
return yaml.load(f)
except Exception as e:
raise EnvironmentError(
"loading yaml file config from {} failed:".format(conf_path), e
)
def rewrite_yaml_conf(conf_path, config):
if not os.path.isabs(conf_path):
conf_path = os.path.join(get_project_base_directory(), conf_path)
try:
with open(conf_path, "w") as f:
yaml = YAML(typ="safe")
yaml.dump(config, f)
except Exception as e:
raise EnvironmentError(
"rewrite yaml file config {} failed:".format(conf_path), e
)
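# Usage sketch of the YAML helpers (this is what update_config builds on); the
# relative path and key are hypothetical, resolved against the project base directory.
def _example_bump_port():
    conf = load_yaml_conf("conf/service_conf.yaml") or {}
    conf["http_port"] = 9380
    rewrite_yaml_conf("conf/service_conf.yaml", conf)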
def rewrite_json_file(filepath, json_data):
with open(filepath, "w") as f:
json.dump(json_data, f, indent=4, separators=(",", ": "))
f.close()
def filename_type(filename):
filename = filename.lower()
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
        return FileType.VISUAL.value
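# Usage sketch of the extension-based classifier above; unknown extensions fall through and yield None.
def _example_classify():
    assert filename_type("Q3 Report.PDF") == FileType.PDF.value
    assert filename_type("notes.md") == FileType.DOC.value
    assert filename_type("podcast.mp3") == FileType.AURAL.value
    assert filename_type("archive.zip") is None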

294
api/utils/log_utils.py Normal file
View File

@ -0,0 +1,294 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import typing
import traceback
import logging
import inspect
from logging.handlers import TimedRotatingFileHandler
from threading import RLock
from web_server.utils import file_utils
class LoggerFactory(object):
TYPE = "FILE"
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
LEVEL = logging.DEBUG
logger_dict = {}
global_handler_dict = {}
LOG_DIR = None
PARENT_LOG_DIR = None
log_share = True
append_to_parent_log = None
lock = RLock()
# CRITICAL = 50
# FATAL = CRITICAL
# ERROR = 40
# WARNING = 30
# WARN = WARNING
# INFO = 20
# DEBUG = 10
# NOTSET = 0
levels = (10, 20, 30, 40)
schedule_logger_dict = {}
@staticmethod
def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False):
if parent_log_dir:
LoggerFactory.PARENT_LOG_DIR = parent_log_dir
if append_to_parent_log:
LoggerFactory.append_to_parent_log = append_to_parent_log
with LoggerFactory.lock:
if not directory:
directory = file_utils.get_project_base_directory("logs")
if not LoggerFactory.LOG_DIR or force:
LoggerFactory.LOG_DIR = directory
if LoggerFactory.log_share:
oldmask = os.umask(000)
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
os.umask(oldmask)
else:
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
for className, (logger, handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(ghandler)
ghandler.close()
LoggerFactory.global_handler_dict = {}
for className, (logger, handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(handler)
_handler = None
if handler:
handler.close()
if className != "default":
_handler = LoggerFactory.get_handler(className)
logger.addHandler(_handler)
LoggerFactory.assemble_global_handler(logger)
LoggerFactory.logger_dict[className] = logger, _handler
@staticmethod
def new_logger(name):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(LoggerFactory.LEVEL)
return logger
@staticmethod
def get_logger(class_name=None):
with LoggerFactory.lock:
if class_name in LoggerFactory.logger_dict.keys():
logger, handler = LoggerFactory.logger_dict[class_name]
if not logger:
logger, handler = LoggerFactory.init_logger(class_name)
else:
logger, handler = LoggerFactory.init_logger(class_name)
return logger
@staticmethod
def get_global_handler(logger_name, level=None, log_dir=None):
if not LoggerFactory.LOG_DIR:
return logging.StreamHandler()
if log_dir:
logger_name_key = logger_name + "_" + log_dir
else:
logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR
# if loggerName not in LoggerFactory.globalHandlerDict:
if logger_name_key not in LoggerFactory.global_handler_dict:
with LoggerFactory.lock:
if logger_name_key not in LoggerFactory.global_handler_dict:
handler = LoggerFactory.get_handler(logger_name, level, log_dir)
LoggerFactory.global_handler_dict[logger_name_key] = handler
return LoggerFactory.global_handler_dict[logger_name_key]
@staticmethod
def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None):
if not log_type:
if not LoggerFactory.LOG_DIR or not class_name:
return logging.StreamHandler()
# return Diy_StreamHandler()
if not log_dir:
log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name))
else:
log_file = os.path.join(log_dir, "{}.log".format(class_name))
else:
log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))
os.makedirs(os.path.dirname(log_file), exist_ok=True)
if LoggerFactory.log_share:
handler = ROpenHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
else:
handler = TimedRotatingFileHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
if level:
handler.level = level
return handler
@staticmethod
def init_logger(class_name):
with LoggerFactory.lock:
logger = LoggerFactory.new_logger(class_name)
handler = None
if class_name:
handler = LoggerFactory.get_handler(class_name)
logger.addHandler(handler)
LoggerFactory.logger_dict[class_name] = logger, handler
else:
LoggerFactory.logger_dict["default"] = logger, handler
LoggerFactory.assemble_global_handler(logger)
return logger, handler
@staticmethod
def assemble_global_handler(logger):
if LoggerFactory.LOG_DIR:
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level]
logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level))
if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level]
logger.addHandler(
LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR))
def setDirectory(directory=None):
LoggerFactory.set_directory(directory)
def setLevel(level):
LoggerFactory.LEVEL = level
def getLogger(className=None, useLevelFile=False):
if className is None:
frame = inspect.stack()[1]
module = inspect.getmodule(frame[0])
className = 'stat'
return LoggerFactory.get_logger(className)
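# Bootstrap sketch: point the factory at a log directory once, then fetch
# per-class loggers; the directory path is hypothetical.
def _example_logging_setup():
    LoggerFactory.set_directory("/tmp/rag_flow_logs")
    LoggerFactory.LEVEL = logging.INFO
    access_log = getLogger("access")   # writes /tmp/rag_flow_logs/access.log plus per-level files
    access_log.info("service started on port %s", 9380)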
def exception_to_trace_string(ex):
return "".join(traceback.TracebackException.from_exception(ex).format())
class ROpenHandler(TimedRotatingFileHandler):
def _open(self):
prevumask = os.umask(000)
rtv = TimedRotatingFileHandler._open(self)
os.umask(prevumask)
return rtv
def sql_logger(job_id='', log_type='sql'):
key = job_id + log_type
if key in LoggerFactory.schedule_logger_dict.keys():
return LoggerFactory.schedule_logger_dict[key]
return get_job_logger(job_id=job_id, log_type=log_type)
def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} ready{suffix}"
def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}start to {msg}{suffix}"
def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} successfully{suffix}"
def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} is not effective{suffix}"
def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}failed to {msg}{suffix}"
def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None):
if detail:
detail_msg = f" detail: \n{detail}"
else:
detail_msg = ""
if task is not None:
return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}"
elif job is not None:
return "", f" on {job.f_role} {job.f_party_id}{detail_msg}"
elif role and party_id:
return "", f" on {role} {party_id}{detail_msg}"
else:
return "", f"{detail_msg}"
def get_logger_base_dir():
job_log_dir = file_utils.get_rag_flow_directory('logs')
return job_log_dir
def get_job_logger(job_id, log_type):
rag_flow_log_dir = file_utils.get_rag_flow_directory('logs', 'rag_flow')
job_log_dir = file_utils.get_rag_flow_directory('logs', job_id)
if not job_id:
log_dirs = [rag_flow_log_dir]
else:
if log_type == 'audit':
log_dirs = [job_log_dir, rag_flow_log_dir]
else:
log_dirs = [job_log_dir]
if LoggerFactory.log_share:
oldmask = os.umask(000)
os.makedirs(job_log_dir, exist_ok=True)
os.makedirs(rag_flow_log_dir, exist_ok=True)
os.umask(oldmask)
else:
os.makedirs(job_log_dir, exist_ok=True)
os.makedirs(rag_flow_log_dir, exist_ok=True)
logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
for job_log_dir in log_dirs:
handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
log_dir=job_log_dir, log_type=log_type, job_id=job_id)
error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id)
logger.addHandler(handler)
logger.addHandler(error_handler)
with LoggerFactory.lock:
LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
return logger

18
api/utils/t_crypt.py Normal file
View File

@ -0,0 +1,18 @@
import base64, os, sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from web_server.utils import decrypt, file_utils
def crypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read())
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8")
if __name__ == "__main__":
pswd = crypt(sys.argv[1])
print(pswd)
print(decrypt(pswd))

30
api/versions.py Normal file
View File

@ -0,0 +1,30 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import dotenv
import typing
from web_server.utils.file_utils import get_project_base_directory
def get_versions() -> typing.Mapping[str, typing.Any]:
return dotenv.dotenv_values(
dotenv_path=os.path.join(get_project_base_directory(), "rag.env")
)
def get_rag_version() -> typing.Optional[str]:
return get_versions().get("RAG")