Test chat API and refine ppt chunker (#42)

This commit is contained in:
KevinHuSh
2024-01-23 19:45:36 +08:00
committed by GitHub
parent 34b2ab3b2f
commit e32ef75e99
10 changed files with 226 additions and 91 deletions

View File

@ -17,7 +17,7 @@ from flask import request
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
embd_mdl = TenantLLMService.model_instance(
dialog.tenant_id, LLMType.EMBEDDING.value)
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs):
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],