add conversation API (#35)

This commit is contained in:
KevinHuSh
2024-01-18 19:28:37 +08:00
committed by GitHub
parent fad2ec7cf3
commit 4a858d33b6
13 changed files with 425 additions and 153 deletions

View File

@ -34,7 +34,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
def get_api_key(cls, tenant_id, model_type, model_name=""):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
@ -42,7 +42,7 @@ class TenantLLMService(CommonService):
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type),
((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
(LLM.status == StatusEnum.VALID)
)
@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: