mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: allows setting multiple types of default models in service config (#9404)
### What problem does this PR solve? Allows set multiple types of default models in service config. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -620,18 +620,35 @@ def user_register(user_id, user):
|
|||||||
"location": "",
|
"location": "",
|
||||||
}
|
}
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
|
||||||
tenant_llm.append(
|
seen = set()
|
||||||
{
|
factory_configs = []
|
||||||
"tenant_id": user_id,
|
for factory_config in [
|
||||||
"llm_factory": settings.LLM_FACTORY,
|
settings.CHAT_CFG,
|
||||||
"llm_name": llm.llm_name,
|
settings.EMBEDDING_CFG,
|
||||||
"model_type": llm.model_type,
|
settings.ASR_CFG,
|
||||||
"api_key": settings.API_KEY,
|
settings.IMAGE2TEXT_CFG,
|
||||||
"api_base": settings.LLM_BASE_URL,
|
settings.RERANK_CFG,
|
||||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
]:
|
||||||
}
|
factory_name = factory_config["factory"]
|
||||||
)
|
if factory_name not in seen:
|
||||||
|
seen.add(factory_name)
|
||||||
|
factory_configs.append(factory_config)
|
||||||
|
|
||||||
|
for factory_config in factory_configs:
|
||||||
|
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||||
|
tenant_llm.append(
|
||||||
|
{
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"llm_factory": factory_config["factory"],
|
||||||
|
"llm_name": llm.llm_name,
|
||||||
|
"model_type": llm.model_type,
|
||||||
|
"api_key": factory_config["api_key"],
|
||||||
|
"api_base": factory_config["base_url"],
|
||||||
|
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if settings.LIGHTEN != 1:
|
if settings.LIGHTEN != 1:
|
||||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||||
@ -647,6 +664,13 @@ def user_register(user_id, user):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unique = {}
|
||||||
|
for item in tenant_llm:
|
||||||
|
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||||
|
if key not in unique:
|
||||||
|
unique[key] = item
|
||||||
|
tenant_llm = list(unique.values())
|
||||||
|
|
||||||
if not UserService.save(**user):
|
if not UserService.save(**user):
|
||||||
return
|
return
|
||||||
TenantService.insert(**tenant)
|
TenantService.insert(**tenant)
|
||||||
|
|||||||
@ -63,12 +63,44 @@ def init_superuser():
|
|||||||
"invited_by": user_info["id"],
|
"invited_by": user_info["id"],
|
||||||
"role": UserTenantRole.OWNER
|
"role": UserTenantRole.OWNER
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user_id = user_info
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
|
||||||
tenant_llm.append(
|
seen = set()
|
||||||
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
|
factory_configs = []
|
||||||
"model_type": llm.model_type,
|
for factory_config in [
|
||||||
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
|
settings.CHAT_CFG["factory"],
|
||||||
|
settings.EMBEDDING_CFG["factory"],
|
||||||
|
settings.ASR_CFG["factory"],
|
||||||
|
settings.IMAGE2TEXT_CFG["factory"],
|
||||||
|
settings.RERANK_CFG["factory"],
|
||||||
|
]:
|
||||||
|
factory_name = factory_config["factory"]
|
||||||
|
if factory_name not in seen:
|
||||||
|
seen.add(factory_name)
|
||||||
|
factory_configs.append(factory_config)
|
||||||
|
|
||||||
|
for factory_config in factory_configs:
|
||||||
|
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||||
|
tenant_llm.append(
|
||||||
|
{
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"llm_factory": factory_config["factory"],
|
||||||
|
"llm_name": llm.llm_name,
|
||||||
|
"model_type": llm.model_type,
|
||||||
|
"api_key": factory_config["api_key"],
|
||||||
|
"api_base": factory_config["base_url"],
|
||||||
|
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
unique = {}
|
||||||
|
for item in tenant_llm:
|
||||||
|
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||||
|
if key not in unique:
|
||||||
|
unique[key] = item
|
||||||
|
tenant_llm = list(unique.values())
|
||||||
|
|
||||||
if not UserService.save(**user_info):
|
if not UserService.save(**user_info):
|
||||||
logging.error("can't init admin.")
|
logging.error("can't init admin.")
|
||||||
|
|||||||
@ -38,6 +38,11 @@ EMBEDDING_MDL = ""
|
|||||||
RERANK_MDL = ""
|
RERANK_MDL = ""
|
||||||
ASR_MDL = ""
|
ASR_MDL = ""
|
||||||
IMAGE2TEXT_MDL = ""
|
IMAGE2TEXT_MDL = ""
|
||||||
|
CHAT_CFG = ""
|
||||||
|
EMBEDDING_CFG = ""
|
||||||
|
RERANK_CFG = ""
|
||||||
|
ASR_CFG = ""
|
||||||
|
IMAGE2TEXT_CFG = ""
|
||||||
API_KEY = None
|
API_KEY = None
|
||||||
PARSERS = None
|
PARSERS = None
|
||||||
HOST_IP = None
|
HOST_IP = None
|
||||||
@ -74,6 +79,7 @@ STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
|
|||||||
|
|
||||||
BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
|
BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_secret_key():
|
def get_or_create_secret_key():
|
||||||
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
|
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
|
||||||
if secret_key and len(secret_key) >= 32:
|
if secret_key and len(secret_key) >= 32:
|
||||||
@ -86,11 +92,9 @@ def get_or_create_secret_key():
|
|||||||
|
|
||||||
# Generate a new secure key and warn about it
|
# Generate a new secure key and warn about it
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
new_key = secrets.token_hex(32)
|
new_key = secrets.token_hex(32)
|
||||||
logging.warning(
|
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
|
||||||
"SECURITY WARNING: Using auto-generated SECRET_KEY. "
|
|
||||||
f"Generated key: {new_key}"
|
|
||||||
)
|
|
||||||
return new_key
|
return new_key
|
||||||
|
|
||||||
|
|
||||||
@ -99,10 +103,10 @@ def init_settings():
|
|||||||
LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
|
LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
|
||||||
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
||||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||||
LLM = get_base_config("user_default_llm", {})
|
LLM = get_base_config("user_default_llm", {}) or {}
|
||||||
LLM_DEFAULT_MODELS = LLM.get("default_models", {})
|
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
|
||||||
LLM_FACTORY = LLM.get("factory")
|
LLM_FACTORY = LLM.get("factory", "") or ""
|
||||||
LLM_BASE_URL = LLM.get("base_url")
|
LLM_BASE_URL = LLM.get("base_url", "") or ""
|
||||||
try:
|
try:
|
||||||
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
|
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -115,29 +119,34 @@ def init_settings():
|
|||||||
FACTORY_LLM_INFOS = []
|
FACTORY_LLM_INFOS = []
|
||||||
|
|
||||||
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
||||||
|
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
|
||||||
if not LIGHTEN:
|
if not LIGHTEN:
|
||||||
EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
|
EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
|
||||||
|
|
||||||
if LLM_DEFAULT_MODELS:
|
|
||||||
CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
|
|
||||||
EMBEDDING_MDL = LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL)
|
|
||||||
RERANK_MDL = LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL)
|
|
||||||
ASR_MDL = LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL)
|
|
||||||
IMAGE2TEXT_MDL = LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)
|
|
||||||
|
|
||||||
# factory can be specified in the config name with "@". LLM_FACTORY will be used if not specified
|
|
||||||
CHAT_MDL = CHAT_MDL + (f"@{LLM_FACTORY}" if "@" not in CHAT_MDL and CHAT_MDL != "" else "")
|
|
||||||
EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
|
|
||||||
RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
|
|
||||||
ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
|
|
||||||
IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
|
|
||||||
|
|
||||||
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
||||||
API_KEY = LLM.get("api_key")
|
API_KEY = LLM.get("api_key")
|
||||||
PARSERS = LLM.get(
|
PARSERS = LLM.get(
|
||||||
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
|
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
|
||||||
|
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
|
||||||
|
rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
|
||||||
|
asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
|
||||||
|
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
|
||||||
|
|
||||||
|
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
|
||||||
|
CHAT_MDL = CHAT_CFG.get("model", "") or ""
|
||||||
|
EMBEDDING_MDL = EMBEDDING_CFG.get("model", "") or ""
|
||||||
|
RERANK_MDL = RERANK_CFG.get("model", "") or ""
|
||||||
|
ASR_MDL = ASR_CFG.get("model", "") or ""
|
||||||
|
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
|
||||||
|
|
||||||
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||||
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||||
|
|
||||||
@ -170,6 +179,7 @@ def init_settings():
|
|||||||
|
|
||||||
retrievaler = search.Dealer(docStoreConn)
|
retrievaler = search.Dealer(docStoreConn)
|
||||||
from graphrag import search as kg_search
|
from graphrag import search as kg_search
|
||||||
|
|
||||||
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
||||||
|
|
||||||
if int(os.environ.get("SANDBOX_ENABLED", "0")):
|
if int(os.environ.get("SANDBOX_ENABLED", "0")):
|
||||||
@ -210,3 +220,34 @@ class RetCode(IntEnum, CustomEnum):
|
|||||||
SERVER_ERROR = 500
|
SERVER_ERROR = 500
|
||||||
FORBIDDEN = 403
|
FORBIDDEN = 403
|
||||||
NOT_FOUND = 404
|
NOT_FOUND = 404
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_model_entry(entry):
|
||||||
|
if isinstance(entry, str):
|
||||||
|
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
name = entry.get("name") or entry.get("model") or ""
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"factory": entry.get("factory"),
|
||||||
|
"api_key": entry.get("api_key"),
|
||||||
|
"base_url": entry.get("base_url"),
|
||||||
|
}
|
||||||
|
return {"name": "", "factory": None, "api_key": None, "base_url": None}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
|
||||||
|
name = (entry_dict.get("name") or "").strip()
|
||||||
|
m_factory = entry_dict.get("factory") or backup_factory or ""
|
||||||
|
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
|
||||||
|
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
|
||||||
|
|
||||||
|
if name and "@" not in name and m_factory:
|
||||||
|
name = f"{name}@{m_factory}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": name,
|
||||||
|
"factory": m_factory,
|
||||||
|
"api_key": m_api_key,
|
||||||
|
"base_url": m_base_url,
|
||||||
|
}
|
||||||
|
|||||||
@ -64,9 +64,21 @@ redis:
|
|||||||
# config:
|
# config:
|
||||||
# oss_table: 'opendal_storage'
|
# oss_table: 'opendal_storage'
|
||||||
# user_default_llm:
|
# user_default_llm:
|
||||||
# factory: 'Tongyi-Qianwen'
|
# factory: 'BAAI'
|
||||||
# api_key: 'sk-xxxxxxxxxxxxx'
|
# api_key: 'backup'
|
||||||
# base_url: ''
|
# base_url: 'backup_base_url'
|
||||||
|
# default_models:
|
||||||
|
# chat_model:
|
||||||
|
# name: 'qwen2.5-7b-instruct'
|
||||||
|
# factory: 'xxxx'
|
||||||
|
# api_key: 'xxxx'
|
||||||
|
# base_url: 'https://api.xx.com'
|
||||||
|
# embedding_model:
|
||||||
|
# name: 'bge-m3'
|
||||||
|
# rerank_model: 'bge-reranker-v2'
|
||||||
|
# asr_model:
|
||||||
|
# model: 'whisper-large-v3' # alias of name
|
||||||
|
# image2text_model: ''
|
||||||
# oauth:
|
# oauth:
|
||||||
# oauth2:
|
# oauth2:
|
||||||
# display_name: "OAuth2"
|
# display_name: "OAuth2"
|
||||||
|
|||||||
Reference in New Issue
Block a user