refine admin initialization (#75)

This commit is contained in:
KevinHuSh
2024-02-27 14:57:34 +08:00
committed by GitHub
parent d1c600d5d3
commit 4568a4b2cb
13 changed files with 91 additions and 34 deletions

View File

@ -16,10 +16,12 @@
import time
import uuid
from api.db import LLMType
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
def init_superuser():
@ -32,8 +34,44 @@ def init_superuser():
"creator": "system",
"status": "1",
}
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})
if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info)
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"])
if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
def init_llm_factory():
factory_infos = [{
@ -171,10 +209,10 @@ def init_llm_factory():
def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time))