Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module initialization no longer connects to the database.
2. Configuration values in settings must now be accessed as settings.CONFIG_NAME.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2024-11-15 17:30:56 +08:00
committed by GitHub
parent ac033b62cf
commit 1e90a1bf36
33 changed files with 452 additions and 411 deletions

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
from flask import request
from api.settings import RetCode
from api import settings
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -44,7 +44,7 @@ def create(tenant_id):
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs]))
if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
# llm
llm = req.get("llm")
@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
if len(embd_count) != 1 :
return get_result(
message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
llm = req.get("llm")
if llm:

View File

@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import (
get_result,
@ -255,7 +255,7 @@ def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(message="Delete dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
)
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets", methods=["GET"])

View File

@ -18,7 +18,7 @@ from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler, kg_retrievaler, RetCode
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required
@ -37,14 +37,14 @@ def retrieval(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
if kb.tenant_id != tenant_id:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
question,
embd_mdl,
@ -72,6 +72,6 @@ def retrieval(tenant_id):
if str(e).find("not_found") > 0:
return build_error_result(
message='No chunk found! Check the chunk status please!',
code=RetCode.NOT_FOUND
code=settings.RetCode.NOT_FOUND
)
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)

View File

@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService
from api.settings import kg_retrievaler
from api import settings
import hashlib
import re
from api.utils.api_utils import token_required
@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL
import os
@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
"""
if "file" not in request.files:
return get_error_data_result(
message="No file part!", code=RetCode.ARGUMENT_ERROR
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
)
file_objs = request.files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_result(
message="No file selected!", code=RetCode.ARGUMENT_ERROR
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
)
# total size
total_size = 0
@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
if total_size > MAX_TOTAL_FILE_SIZE:
return get_result(
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
)
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
# rename key's name
renamed_doc_list = []
for file in files:
@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id):
if "name" in req and req["name"] != doc.name:
if (
pathlib.Path(req["name"].lower()).suffix
!= pathlib.Path(doc.name.lower()).suffix
pathlib.Path(req["name"].lower()).suffix
!= pathlib.Path(doc.name.lower()).suffix
):
return get_result(
message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream:
return construct_json_result(
message="This file is empty.", code=RetCode.DATA_ERROR
message="This file is empty.", code=settings.RetCode.DATA_ERROR
)
file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type
@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
errors += str(e)
if errors:
return get_result(message=errors, code=RetCode.SERVER_ERROR)
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
return get_result()
@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
info["chunk_num"] = 0
info["token_num"] = 0
DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
)
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
highlight=True)
res["total"] = sres.total
sign = 0
for id in sres.ids:
@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
condition = {"doc_id": document_id}
if "chunk_ids" in req:
condition["id"] = req["chunk_ids"]
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()
@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
)
if "question" not in req:
return get_error_data_result("`question` is required.")
@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
question,
embd_mdl,
@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
if str(e).find("not_found") > 0:
return get_result(
message="No chunk found! Check the chunk status please!",
code=RetCode.DATA_ERROR,
code=settings.RetCode.DATA_ERROR,
)
return server_error_response(e)
return server_error_response(e)