Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module init won't connect database any more.
2. Config in settings need to be used with settings.CONFIG_NAME

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2024-11-15 17:30:56 +08:00
committed by GitHub
parent ac033b62cf
commit 1e90a1bf36
33 changed files with 452 additions and 411 deletions

View File

@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
@ -141,7 +141,7 @@ def set_conversation():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent":
@ -183,7 +183,7 @@ def completion():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
@ -290,8 +290,8 @@ def completion():
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
#******************For dialog******************
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
@ -326,7 +326,7 @@ def completion():
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
answer = None
for ans in chat(dia, msg, **req):
answer = ans
@ -347,8 +347,8 @@ def get(conversation_id):
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
try:
e, conv = API4ConversationService.get_by_id(conversation_id)
if not e:
@ -357,8 +357,8 @@ def get(conversation_id):
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0:
continue
@ -378,7 +378,7 @@ def upload():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
@ -394,12 +394,12 @@ def upload():
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
@ -490,17 +490,17 @@ def upload_parse():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@ -513,7 +513,7 @@ def list_chunks():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
@ -531,7 +531,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@ -553,7 +553,7 @@ def list_kb_docs():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
tenant_id = objs[0].tenant_id
@ -585,6 +585,7 @@ def list_kb_docs():
except Exception as e:
return server_error_response(e)
@manager.route('/document/infos', methods=['POST'])
@validate_request("doc_ids")
def docinfos():
@ -592,7 +593,7 @@ def docinfos():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
@ -606,7 +607,7 @@ def document_rm():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
req = request.json
@ -653,7 +654,7 @@ def document_rm():
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True)
@ -668,7 +669,7 @@ def completion_faq():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
@ -805,10 +806,10 @@ def retrieval():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
kb_ids = req.get("kb_id",[])
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", [])
question = req.get("question")
page = int(req.get("page", 1))
@ -822,20 +823,21 @@ def retrieval():
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=settings.RetCode.AUTHENTICATION_ERROR)
embd_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
@ -843,5 +845,5 @@ def retrieval():
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)