llm configuation refine and trievalTest API refine (#40)

This commit is contained in:
KevinHuSh
2024-01-19 19:51:57 +08:00
committed by GitHub
parent f3dd131403
commit 484e5abc1f
39 changed files with 160 additions and 121 deletions

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.user_service import TenantService
from rag.llm import EmbeddingModel, CvModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
@ -34,40 +35,39 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type, model_name=""):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
(LLM.status == StatusEnum.VALID)
)
if not objs:return
def get_api_key(cls, tenant_id, model_name):
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
if not objs: return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory==LLMFactories.name)).where(cls.model.tenant_id==tenant_id).dicts()
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else:
model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING:
e,tenant = TenantService.get_by_id(tenant_id)
if not e: raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
else: assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])