mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Test chat API and refine ppt chunker (#42)
This commit is contained in:
@ -14,12 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from api.settings import database_logger
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db import StatusEnum
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
|
||||
if not objs: return
|
||||
if not objs:
|
||||
return
|
||||
return objs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
|
||||
fields = [
|
||||
cls.model.llm_factory,
|
||||
LLMFactories.logo,
|
||||
LLMFactories.tags,
|
||||
cls.model.model_type,
|
||||
cls.model.llm_name]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id).dicts()
|
||||
|
||||
@ -51,23 +57,96 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type):
|
||||
e,tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e: raise LookupError("Tenant not found")
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
|
||||
elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
|
||||
else: assert False, "LLM type error"
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
|
||||
if not model_config:
|
||||
raise LookupError("Model({}) not found".format(mdlnm))
|
||||
model_config = model_config.to_dict()
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel: return
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"])
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"])
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
||||
.execute()
|
||||
return num
|
||||
|
||||
|
||||
class LLMBundle(object):
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
|
||||
assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
emd, used_tokens = self.mdl.encode(texts, batch_size)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
return emd, used_tokens
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
return emd, used_tokens
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
||||
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
Reference in New Issue
Block a user