Test chat API and refine ppt chunker (#42)

This commit is contained in:
KevinHuSh
2024-01-23 19:45:36 +08:00
committed by GitHub
parent 34b2ab3b2f
commit e32ef75e99
10 changed files with 226 additions and 91 deletions

View File

@ -14,12 +14,12 @@
# limitations under the License.
#
from api.db.services.user_service import TenantService
from rag.llm import EmbeddingModel, CvModel
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService
from api.db import StatusEnum
class LLMFactoriesService(CommonService):
@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
    """Return the stored record for one of a tenant's models.

    Runs ``cls.query(tenant_id=tenant_id, llm_name=model_name)`` and returns
    the first matching record.

    Args:
        tenant_id: Identifier of the tenant that owns the model entry.
        model_name: Value matched against the ``llm_name`` field.

    Returns:
        The first matching record, or ``None`` when the query yields nothing.
    """
    matches = cls.query(tenant_id=tenant_id, llm_name=model_name)
    # A single guard is enough; an empty result means "not configured".
    if not matches:
        return None
    return matches[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
fields = [
cls.model.llm_factory,
LLMFactories.logo,
LLMFactories.tags,
cls.model.model_type,
cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
@ -51,23 +57,96 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None):
    """Instantiate the model a tenant has configured for ``llm_type``.

    The model name is taken from the tenant record (``embd_id``, ``asr_id``,
    ``img2txt_id`` or ``llm_id``). For chat models only, a non-empty
    ``llm_name`` overrides the tenant default.

    Args:
        tenant_id: Identifier passed to ``TenantService.get_by_id``.
        llm_type: One of the ``LLMType`` member values.
        llm_name: Optional chat model name; ignored for other model types.

    Returns:
        An instance built from ``EmbeddingModel``, ``CvModel`` or ``ChatModel``
        (selected by ``llm_type`` and keyed by the record's ``llm_factory``),
        constructed with the stored ``api_key`` and ``llm_name``. Returns
        ``None`` when the factory is not registered, and also for
        speech-to-text, for which no registry is consulted here.

    Raises:
        LookupError: If the tenant or the model record cannot be found.
        AssertionError: If ``llm_type`` is not a recognised type.
    """
    e, tenant = TenantService.get_by_id(tenant_id)
    if not e:
        raise LookupError("Tenant not found")

    if llm_type == LLMType.EMBEDDING.value:
        mdlnm = tenant.embd_id
    elif llm_type == LLMType.SPEECH2TEXT.value:
        mdlnm = tenant.asr_id
    elif llm_type == LLMType.IMAGE2TEXT.value:
        mdlnm = tenant.img2txt_id
    elif llm_type == LLMType.CHAT.value:
        mdlnm = tenant.llm_id if not llm_name else llm_name
    else:
        # Raised explicitly: `assert False` disappears under `python -O`,
        # which would leave `mdlnm` unbound below.
        raise AssertionError("LLM type error")

    model_config = cls.get_api_key(tenant_id, mdlnm)
    if not model_config:
        raise LookupError("Model({}) not found".format(mdlnm))
    model_config = model_config.to_dict()

    # Pick the registry for this model type; speech-to-text has none here.
    if llm_type == LLMType.EMBEDDING.value:
        registry = EmbeddingModel
    elif llm_type == LLMType.IMAGE2TEXT.value:
        registry = CvModel
    elif llm_type == LLMType.CHAT.value:
        registry = ChatModel
    else:
        return None

    factory = model_config["llm_factory"]
    if factory not in registry:
        return None
    return registry[factory](
        model_config["api_key"], model_config["llm_name"])
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
    """Add ``used_tokens`` to the usage counter of a tenant's model.

    The model name is resolved from the tenant record (``embd_id``,
    ``asr_id``, ``img2txt_id`` or ``llm_id``). For chat models only, a
    non-empty ``llm_name`` overrides the tenant default.

    Args:
        tenant_id: Identifier passed to ``TenantService.get_by_id``.
        llm_type: One of the ``LLMType`` member values.
        used_tokens: Amount added to the stored ``used_tokens`` column.
        llm_name: Optional chat model name; ignored for other model types.

    Returns:
        Whatever the update query's ``execute()`` returns (with peewee, the
        number of rows modified), so a falsy result means nothing was updated.

    Raises:
        LookupError: If the tenant cannot be found.
        AssertionError: If ``llm_type`` is not a recognised type.
    """
    e, tenant = TenantService.get_by_id(tenant_id)
    if not e:
        raise LookupError("Tenant not found")

    if llm_type == LLMType.EMBEDDING.value:
        mdlnm = tenant.embd_id
    elif llm_type == LLMType.SPEECH2TEXT.value:
        mdlnm = tenant.asr_id
    elif llm_type == LLMType.IMAGE2TEXT.value:
        mdlnm = tenant.img2txt_id
    elif llm_type == LLMType.CHAT.value:
        mdlnm = tenant.llm_id if not llm_name else llm_name
    else:
        # Raised explicitly: `assert False` disappears under `python -O`,
        # which would leave `mdlnm` unbound below.
        raise AssertionError("LLM type error")

    # The increment is expressed in SQL (column + value) rather than as a
    # read-modify-write in Python.
    num = (
        cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
        .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)
        .execute()
    )
    return num
class LLMBundle(object):
    """A tenant's model instance that records token usage on every call.

    The wrapped model comes from ``TenantLLMService.model_instance``. Each
    public method forwards to the model, then adds the reported token count to
    the tenant's usage via ``TenantLLMService.increase_usage``. A failed usage
    update is logged and does not interrupt the call.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None):
        """Resolve and store the model for ``tenant_id`` / ``llm_type``.

        Raises:
            AssertionError: If no model instance could be created.
        """
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
        # Raised explicitly so the check is not stripped under `python -O`.
        if not self.mdl:
            raise AssertionError(
                "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name))

    def _record_usage(self, used_tokens, label, llm_name=None):
        """Add ``used_tokens`` to the tenant's usage; log if nothing was updated."""
        updated = TenantLLMService.increase_usage(
            self.tenant_id, self.llm_type, used_tokens, llm_name)
        # A falsy result means the update touched no rows, which is the
        # failure case worth logging.
        if not updated:
            database_logger.error(
                "Can't update token usage for {}/{}".format(self.tenant_id, label))

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts``; return ``(embeddings, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode(texts, batch_size)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def encode_queries(self, query: str):
        """Embed one query; return ``(embedding, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode_queries(query)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def describe(self, image, max_tokens=300):
        """Return the model's text description of ``image``."""
        txt, used_tokens = self.mdl.describe(image, max_tokens)
        self._record_usage(used_tokens, "IMAGE2TEXT")
        return txt

    def chat(self, system, history, gen_conf):
        """Return the model's chat reply for ``system``, ``history``, ``gen_conf``."""
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        # Usage is charged to the explicitly requested chat model, if any.
        self._record_usage(used_tokens, "CHAT", self.llm_name)
        return txt