refine admin initialization (#75)

This commit is contained in:
KevinHuSh
2024-02-27 14:57:34 +08:00
committed by GitHub
parent d1c600d5d3
commit 4568a4b2cb
13 changed files with 91 additions and 34 deletions

View File

@ -20,7 +20,7 @@ from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, huqie, retrievaler
from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.settings import RetCode, retrievaler
from api.utils.api_utils import get_json_result
import hashlib
import re

View File

@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger, stat_logger
from api.settings import access_logger, stat_logger, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.app.resume import forbidden_select_fields4resume
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder, rmSpace

View File

@ -16,10 +16,12 @@
import time
import uuid
from api.db import LLMType
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
def init_superuser():
@ -32,8 +34,44 @@ def init_superuser():
"creator": "system",
"status": "1",
}
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})
if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info)
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"])
if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
def init_llm_factory():
factory_infos = [{
@ -171,10 +209,10 @@ def init_llm_factory():
def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time))

View File

@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger
from rag.nlp import search
from rag.utils import ELASTICSEARCH
# Server
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
class CustomEnum(Enum):
@classmethod
def valid(cls, value):