Support new feature about Ollama (#262)

### What problem does this PR solve?

Issue link:#221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-04-08 19:59:31 +08:00
committed by GitHub
parent 3708b97db9
commit b6887a20f8
4 changed files with 17 additions and 8 deletions

View File

@ -20,7 +20,7 @@ from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
from api.db.services.llm_service import LLMService, LLMBundle, TenantLLMService
from api.settings import access_logger, stat_logger, retrievaler, chat_logger
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
@ -184,8 +184,11 @@ def chat(dialog, messages, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
llm = LLMService.query(llm_name=dialog.llm_id)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
llm = llm[0]
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 1024
else: max_tokens = llm[0].max_tokens
questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@ -227,11 +230,11 @@ def chat(dialog, messages, **kwargs):
gen_conf = dialog.llm_setting
msg = [{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"]
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(
gen_conf["max_tokens"],
llm.max_tokens - used_token_count)
max_tokens - used_token_count)
answer = chat_mdl.chat(
prompt_config["system"].format(
**kwargs), msg, gen_conf)