Support Ollama (#261)

### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-04-08 19:20:57 +08:00
committed by GitHub
parent 265a7a283a
commit 3708b97db9
15 changed files with 234 additions and 43 deletions

View File

@ -126,7 +126,7 @@ def message_fit_in(msg, max_length=4000):
if c < max_length:
return c, msg
msg_ = [m for m in msg[:-1] if m.role == "system"]
msg_ = [m for m in msg[:-1] if m["role"] == "system"]
msg_.append(msg[-1])
msg = msg_
c = count()

View File

@ -81,7 +81,7 @@ def upload():
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": filename_type(filename),
"type": filetype,
"name": filename,
"location": location,
"size": len(blob),

View File

@ -91,6 +91,57 @@ def set_api_key():
return get_json_result(data=True)
@manager.route('/add_llm', methods=['POST'])
@login_required
@validate_request("llm_factory", "llm_name", "model_type")
def add_llm():
    """Register an LLM for the current user after probing that it is reachable.

    Reads ``llm_factory``, ``llm_name``, ``model_type`` and the optional
    ``api_base`` from the JSON request body. Embedding and chat models are
    exercised with a single test call before anything is stored; any other
    model type is stored without a check.

    Returns:
        ``get_data_error_result(retmsg=...)`` when the probe fails, otherwise
        ``get_json_result(data=True)`` once the matching record has been
        updated or, if none was updated, saved as a new one.
    """
    req = request.json
    factory = req["llm_factory"]
    llm = {
        "tenant_id": current_user.id,
        "llm_factory": factory,
        "model_type": req["model_type"],
        "llm_name": req["llm_name"],
        "api_base": req.get("api_base", ""),
        # Fixed non-empty placeholder; no real key is read from the request.
        # NOTE(review): presumably needed because the model listing skips
        # records whose api_key is empty -- confirm.
        "api_key": "xxxxxxxxxxxxxxx"
    }

    msg = ""
    if llm["model_type"] == LLMType.EMBEDDING.value:
        try:
            # The lookup and construction live inside the try so that an
            # unsupported factory name or a failing constructor is reported
            # as a data error instead of escaping the handler.
            mdl = EmbeddingModel[factory](
                key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
            arr, tc = mdl.encode(["Test if the api key is available"])
            # An empty first vector or a zero token count is treated as failure.
            if len(arr[0]) == 0 or tc == 0:
                raise Exception("Fail")
        except Exception as e:
            msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
    elif llm["model_type"] == LLMType.CHAT.value:
        try:
            # Same reasoning as above: keep the factory lookup guarded.
            mdl = ChatModel[factory](
                key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                "temperature": 0.9})
            # A falsy token count is treated as failure; the returned text is
            # then used as the error detail.
            if not tc:
                raise Exception(m)
        except Exception as e:
            msg += f"\nFail to access model({llm['llm_name']})." + str(
                e)
    else:
        # TODO: check other type of models
        pass

    if msg:
        return get_data_error_result(retmsg=msg)

    # Update the existing (tenant, factory, model name) record; fall back to
    # creating one when filter_update reports that nothing was updated.
    if not TenantLLMService.filter_update(
            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
        TenantLLMService.save(**llm)

    return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
@ -125,6 +176,12 @@ def list():
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
llm_set = set([m["llm_name"] for m in llms])
for o in objs:
if not o.api_key:continue
if o.llm_name in llm_set:continue
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
res = {}
for m in llms:
if model_type and m["model_type"] != model_type:

View File

@ -181,6 +181,10 @@ def user_info():
def rollback_user_registration(user_id):
try:
UserService.delete_by_id(user_id)
except Exception as e:
pass
try:
TenantService.delete_by_id(user_id)
except Exception as e:

View File

@ -18,7 +18,7 @@ import time
import uuid
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
@ -100,16 +100,16 @@ factory_infos = [{
"status": "1",
},
{
"name": "Local",
"name": "Ollama",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
}, {
"name": "Moonshot",
"name": "Moonshot",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}
},
# {
# "name": "文心一言",
# "logo": "",
@ -230,20 +230,6 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
# ---------------------- 本地 ----------------------
{
"fid": factory_infos[3]["name"],
"llm_name": "qwen-14B-chat",
"tags": "LLM,CHAT,",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[3]["name"],
"llm_name": "flag-embedding",
"tags": "TEXT EMBEDDING,",
"max_tokens": 128 * 1000,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ Moonshot -----------------------
{
"fid": factory_infos[4]["name"],
@ -282,6 +268,9 @@ def init_llm_factory():
except Exception as e:
pass
LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
LLMService.filter_delete([LLM.fid=="Local"])
"""
drop table llm;
drop table llm_factories;
@ -295,8 +284,7 @@ def init_llm_factory():
def init_web_data():
start_time = time.time()
if LLMFactoriesService.get_all().count() != len(factory_infos):
init_llm_factory()
init_llm_factory()
if not UserService.get_all().count():
init_superuser()