Test chat API and refine ppt chunker (#42)

This commit is contained in:
KevinHuSh
2024-01-23 19:45:36 +08:00
committed by GitHub
parent 34b2ab3b2f
commit e32ef75e99
10 changed files with 226 additions and 91 deletions

View File

@ -17,7 +17,7 @@ from flask import request
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
embd_mdl = TenantLLMService.model_instance(
dialog.tenant_id, LLMType.EMBEDDING.value)
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs):
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],

View File

@ -524,6 +524,7 @@ class Dialog(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")
kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -14,12 +14,12 @@
# limitations under the License.
#
from api.db.services.user_service import TenantService
from rag.llm import EmbeddingModel, CvModel
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService
from api.db import StatusEnum
class LLMFactoriesService(CommonService):
@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
if not objs: return
if not objs:
return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
fields = [
cls.model.llm_factory,
LLMFactories.logo,
LLMFactories.tags,
cls.model.model_type,
cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
@ -51,23 +57,96 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
e,tenant = TenantService.get_by_id(tenant_id)
if not e: raise LookupError("Tenant not found")
def model_instance(cls, tenant_id, llm_type, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
else: assert False, "LLM type error"
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
if not model_config:
raise LookupError("Model({}) not found".format(mdlnm))
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
    """Add ``used_tokens`` to the usage counter of one of a tenant's models.

    The model is chosen from the tenant record according to ``llm_type``;
    for the chat type an explicit ``llm_name`` takes precedence over the
    tenant's default chat model.

    :param tenant_id: id looked up through ``TenantService.get_by_id``.
    :param llm_type: one of the ``LLMType`` values (embedding,
        speech2text, image2text, chat).
    :param used_tokens: amount added to the stored ``used_tokens`` column.
    :param llm_name: optional model name; only consulted for the chat type.
    :return: whatever the update query's ``execute()`` returns
        (presumably the number of rows modified -- confirm against the
        ORM documentation).
    :raises LookupError: if the tenant cannot be found.
    :raises AssertionError: if ``llm_type`` is not a recognised type.
    """
    e, tenant = TenantService.get_by_id(tenant_id)
    if not e:
        raise LookupError("Tenant not found")

    if llm_type == LLMType.EMBEDDING.value:
        mdlnm = tenant.embd_id
    elif llm_type == LLMType.SPEECH2TEXT.value:
        mdlnm = tenant.asr_id
    elif llm_type == LLMType.IMAGE2TEXT.value:
        mdlnm = tenant.img2txt_id
    elif llm_type == LLMType.CHAT.value:
        mdlnm = tenant.llm_id if not llm_name else llm_name
    else:
        # Raised explicitly instead of ``assert False``: asserts are
        # stripped under ``python -O``, which would leave ``mdlnm``
        # unbound and surface as an unrelated UnboundLocalError below.
        raise AssertionError("LLM type error")

    # Increment in the database (column = column + used_tokens) rather
    # than reading the value first and writing it back.
    num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
        .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
        .execute()
    return num
class LLMBundle(object):
    """Wrap a tenant's model instance and record token usage per call.

    The wrapped model is obtained from ``TenantLLMService.model_instance``.
    Each public method forwards to the same-named method of that model and
    then reports the consumed tokens through
    ``TenantLLMService.increase_usage``. A failed usage update is only
    logged; it never prevents the model result from being returned.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None):
        """Resolve the model for ``tenant_id`` / ``llm_type`` / ``llm_name``.

        :raises AssertionError: if no model instance could be resolved.
        """
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
        if not self.mdl:
            # Explicit raise (same exception type as the former ``assert``)
            # so the check also holds when Python runs with ``-O``.
            raise AssertionError(
                "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name))

    def _record_usage(self, used_tokens, label, llm_name=None):
        """Report ``used_tokens`` and log an error when nothing was updated."""
        # A falsy result from increase_usage means the update did not take
        # effect, so that is the case to log.
        if not TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens, llm_name):
            database_logger.error(
                "Can't update token usage for {}/{}".format(self.tenant_id, label))

    def encode(self, texts: list, batch_size=32):
        """Forward to the model's ``encode``; return ``(embeddings, used_tokens)``."""
        emd, used_tokens = self.mdl.encode(texts, batch_size)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def encode_queries(self, query: str):
        """Forward to the model's ``encode_queries``; return ``(embedding, used_tokens)``."""
        emd, used_tokens = self.mdl.encode_queries(query)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def describe(self, image, max_tokens=300):
        """Forward to the model's ``describe``; return only the text."""
        txt, used_tokens = self.mdl.describe(image, max_tokens)
        self._record_usage(used_tokens, "IMAGE2TEXT")
        return txt

    def chat(self, system, history, gen_conf):
        """Forward to the model's ``chat``; return only the answer text."""
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        # Chat usage is booked on the explicitly selected model name.
        self._record_usage(used_tokens, "CHAT", self.llm_name)
        return txt

View File

@ -143,11 +143,11 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
return FileType.VISUAL
return FileType.VISUAL