mirror of https://github.com/infiniflow/ragflow.git
refine admin initialization (#75)
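Judging from the hunks below, this commit does two things: it moves the shared `retrievaler` instance (a `search.Dealer` bound to `ELASTICSEARCH`) out of `rag.nlp` and into `api.settings`, updating the app modules to import it from there, and it expands `init_superuser()` so the default admin also gets a tenant record, an OWNER user-tenant link and per-model `tenant_llm` rows, followed by a quick probe of the configured chat and embedding models. A condensed sketch of the new retriever wiring, restated from the settings and import hunks below (not a complete module):

    # The settings module now constructs one shared retriever at import time
    # (last two hunks), instead of each app importing it from rag.nlp:
    from rag.nlp import search
    from rag.utils import ELASTICSEARCH

    retrievaler = search.Dealer(ELASTICSEARCH)

    # Consumers switch from `from rag.nlp import ..., retrievaler`
    # to `from api.settings import ..., retrievaler` (first three hunks).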
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, huqie, retrievaler
+from rag.nlp import search, huqie
 from rag.utils import ELASTICSEARCH, rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, retrievaler
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger
+from api.settings import access_logger, stat_logger, retrievaler
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.app.resume import forbidden_select_fields4resume
-from rag.llm import ChatModel
-from rag.nlp import retrievaler
 from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder, rmSpace
@@ -16,10 +16,12 @@
 import time
 import uuid
 
-from api.db import LLMType
+from api.db import LLMType, UserTenantRole
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.services import UserService
-from api.db.services.llm_service import LLMFactoriesService, LLMService
+from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
+from api.db.services.user_service import TenantService, UserTenantService
+from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
 
 
 def init_superuser():
@@ -32,8 +34,44 @@ def init_superuser():
         "creator": "system",
         "status": "1",
     }
+    tenant = {
+        "id": user_info["id"],
+        "name": user_info["nickname"] + "‘s Kingdom",
+        "llm_id": CHAT_MDL,
+        "embd_id": EMBEDDING_MDL,
+        "asr_id": ASR_MDL,
+        "parser_ids": PARSERS,
+        "img2txt_id": IMAGE2TEXT_MDL
+    }
+    usr_tenant = {
+        "tenant_id": user_info["id"],
+        "user_id": user_info["id"],
+        "invited_by": user_info["id"],
+        "role": UserTenantRole.OWNER
+    }
+    tenant_llm = []
+    for llm in LLMService.query(fid=LLM_FACTORY):
+        tenant_llm.append(
+            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
+             "api_key": API_KEY})
+
+    if not UserService.save(**user_info):
+        print("【ERROR】can't init admin.")
+        return
+    TenantService.save(**tenant)
+    UserTenantService.save(**usr_tenant)
+    TenantLLMService.insert_many(tenant_llm)
-    UserService.save(**user_info)
+
+    chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
+    msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
+    if msg.find("ERROR: ") == 0:
+        print("【ERROR】: '{}' doesn't work. {}".format(tenant["llm_id"], msg))
+    embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
+    v, c = embd_mdl.encode(["Hello!"])
+    if c == 0:
+        print("【ERROR】: '{}' doesn't work...".format(tenant["embd_id"]))
 
 
 def init_llm_factory():
     factory_infos = [{
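The tail of the new `init_superuser()` doubles as a startup health check: it wraps the tenant's chat and embedding models in `LLMBundle` and fires a trivial request at each, so a bad `LLM_FACTORY`/`API_KEY` configuration surfaces during initialization rather than on the first real request. A hedged sketch of that probe pulled out into a standalone helper; `check_default_llms` is a hypothetical name, and only the imports and the `LLMBundle`, `chat()` and `encode()` calls are taken from the hunks above:

    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle

    def check_default_llms(tenant: dict):
        # Hypothetical helper mirroring the inline checks added to init_superuser().
        chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
        msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
        if msg.find("ERROR: ") == 0:
            # The chat model answered with the error marker: unreachable or misconfigured.
            print("【ERROR】: '{}' doesn't work. {}".format(tenant["llm_id"], msg))

        embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
        v, c = embd_mdl.encode(["Hello!"])
        if c == 0:
            # A zero count from encode() is treated as "nothing was embedded" in the original code.
            print("【ERROR】: '{}' doesn't work...".format(tenant["embd_id"]))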
@@ -171,10 +209,10 @@ def init_llm_factory():
 
 
 def init_web_data():
     start_time = time.time()
-    if not UserService.get_all().count():
-        init_superuser()
 
     if not LLMService.get_all().count():init_llm_factory()
+    if not UserService.get_all().count():
+        init_superuser()
 
     print("init web data success:{}".format(time.time() - start_time))
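The reordering in this hunk is presumably forced by the expanded `init_superuser()` above: it builds the admin's `tenant_llm` rows from `LLMService.query(fid=LLM_FACTORY)`, which returns nothing until `init_llm_factory()` has seeded the LLM tables, so on an empty database the factory data now has to be written first. The same order, annotated (the comments are my reading of the dependency, the calls come from the hunk):

    # Seed LLM factories and their models first ...
    if not LLMService.get_all().count():
        init_llm_factory()
    # ... so that init_superuser() finds models for LLM_FACTORY when it
    # builds the admin tenant's tenant_llm records.
    if not UserService.get_all().count():
        init_superuser()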
@@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger
 
+from rag.nlp import search
+from rag.utils import ELASTICSEARCH
 
 # Server
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
 SERVER_MODULE = "rag_flow_server.py"
@@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60  # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False
 
+retrievaler = search.Dealer(ELASTICSEARCH)
+
 class CustomEnum(Enum):
     @classmethod
     def valid(cls, value):