mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add conversation API (#35)
This commit is contained in:
@ -34,7 +34,7 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_type):
|
||||
def get_api_key(cls, tenant_id, model_type, model_name=""):
|
||||
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
||||
if objs and len(objs)>0 and objs[0].llm_name:
|
||||
return objs[0]
|
||||
@ -42,7 +42,7 @@ class TenantLLMService(CommonService):
|
||||
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
||||
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
||||
(cls.model.tenant_id == tenant_id),
|
||||
(cls.model.model_type == model_type),
|
||||
((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
|
||||
(LLM.status == StatusEnum.VALID)
|
||||
)
|
||||
|
||||
@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type):
|
||||
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
||||
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
|
||||
if not model_config:
|
||||
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user