Test APIs and fix bugs (#41)

This commit is contained in:
KevinHuSh
2024-01-22 19:51:38 +08:00
committed by GitHub
parent 484e5abc1f
commit 34b2ab3b2f
11 changed files with 46 additions and 27 deletions

View File

@ -214,7 +214,7 @@ def retrieval_test():
question = req["question"]
kb_id = req["kb_id"]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.4))
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
try:

View File

@ -170,7 +170,7 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
@ -186,10 +186,10 @@ def chat(dialog, messages, **kwargs):
kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting[dialog.llm_setting_type]
msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97))
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
mdl = ChatModel[model_config.llm_factory](model_config["api_key"], dialog.llm_id)
mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = retrievaler.insert_citations(answer,
@ -198,4 +198,6 @@ def chat(dialog, messages, **kwargs):
embd_mdl,
tkweight=1-dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
for c in kbinfos["chunks"]:
if c.get("vector"):del c["vector"]
return {"answer": answer, "retrieval": kbinfos}

View File

@ -11,7 +11,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License
#
#
import base64
import pathlib
@ -65,7 +66,7 @@ def upload():
while MINIO.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, filename, blob)
MINIO.put(kb_id, location, blob)
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
@ -188,7 +189,10 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id))
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]):

View File

@ -75,7 +75,7 @@ def list():
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m.llm_name in mdlnms
m["available"] = m["llm_name"] in mdlnms
res = {}
for m in llms:

View File

@ -469,7 +469,7 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
similarity_threshold = FloatField(default=0.4)
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
@ -521,7 +521,7 @@ class Dialog(DataBaseModel):
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
similarity_threshold = FloatField(default=0.4)
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)

View File

@ -63,7 +63,7 @@ class TenantLLMService(CommonService):
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
model_config = model_config[0].to_dict()
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])

View File

@ -143,7 +143,7 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):