llm configuation refine and trievalTest API refine (#40)

This commit is contained in:
KevinHuSh
2024-01-19 19:51:57 +08:00
committed by GitHub
parent f3dd131403
commit 484e5abc1f
39 changed files with 160 additions and 121 deletions

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,22 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import tiktoken
from flask import request
from flask_login import login_required, current_user
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import StatusEnum, LLMType
from api.db.services.kb_service import KnowledgebaseService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.query import EsQueryer
from rag.utils import num_tokens_from_string, encoder
@ -142,6 +136,27 @@ def message_fit_in(msg, max_length=4000):
return max_length, msg
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
def completion():
req = request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":continue
if m["role"] == "assistant" and not msg:continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, dia = DialogService.get_by_id(req["dialog_id"])
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["dialog_id"]
del req["messages"]
return get_json_result(data=chat(dia, msg, **req))
except Exception as e:
return server_error_response(e)
def chat(dialog, messages, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
llm = LLMService.query(llm_name=dialog.llm_id)
@ -156,7 +171,7 @@ def chat(dialog, messages, **kwargs):
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
if not model_config: raise LookupError("LLM(%s) API key not found"%dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
embd_mdl = TenantLLMService.model_instance(
@ -183,25 +198,4 @@ def chat(dialog, messages, **kwargs):
embd_mdl,
tkweight=1-dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
return {"answer": answer, "retrieval": kbinfos}
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
def completion():
req = request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":continue
if m["role"] == "assistant" and not msg:continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, dia = DialogService.get_by_id(req["dialog_id"])
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["dialog_id"]
del req["messages"]
return get_json_result(data=chat(dia, msg, **req))
except Exception as e:
return server_error_response(e)
return {"answer": answer, "retrieval": kbinfos}