From 421657f64beb59b142a84b8792296db46ca337a4 Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Wed, 13 Aug 2025 09:46:05 +0800
Subject: [PATCH] Feat: allows setting multiple types of default models in service config (#9404)

### What problem does this PR solve?

Allows setting multiple types of default models in service config.

Each entry under `user_default_llm.default_models` may be either a plain
model-name string or a mapping with `name` (alias: `model`), `factory`,
`api_key` and `base_url`. Fields omitted from an entry fall back to the
top-level `factory`, `api_key` and `base_url` of `user_default_llm`.
A model name without an explicit `@factory` suffix gets the resolved
factory appended.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/apps/user_app.py   | 48 +++++++++++++++++------
 api/db/init_data.py    | 44 ++++++++++++++++++---
 api/settings.py        | 89 ++++++++++++++++++++++++++++++------------
 conf/service_conf.yaml | 18 +++++++--
 4 files changed, 154 insertions(+), 45 deletions(-)

diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index b8d66ecba..93b07615c 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -620,18 +620,35 @@ def user_register(user_id, user):
         "location": "",
     }
     tenant_llm = []
-    for llm in LLMService.query(fid=settings.LLM_FACTORY):
-        tenant_llm.append(
-            {
-                "tenant_id": user_id,
-                "llm_factory": settings.LLM_FACTORY,
-                "llm_name": llm.llm_name,
-                "model_type": llm.model_type,
-                "api_key": settings.API_KEY,
-                "api_base": settings.LLM_BASE_URL,
-                "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
-            }
-        )
+
+    seen = set()
+    factory_configs = []
+    for factory_config in [
+        settings.CHAT_CFG,
+        settings.EMBEDDING_CFG,
+        settings.ASR_CFG,
+        settings.IMAGE2TEXT_CFG,
+        settings.RERANK_CFG,
+    ]:
+        factory_name = factory_config["factory"]
+        if factory_name not in seen:
+            seen.add(factory_name)
+            factory_configs.append(factory_config)
+
+    for factory_config in factory_configs:
+        for llm in LLMService.query(fid=factory_config["factory"]):
+            tenant_llm.append(
+                {
+                    "tenant_id": user_id,
+                    "llm_factory": factory_config["factory"],
+                    "llm_name": llm.llm_name,
+                    "model_type": llm.model_type,
+                    "api_key": factory_config["api_key"],
+                    "api_base": factory_config["base_url"],
+                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
+                }
+            )
+
     if settings.LIGHTEN != 1:
         for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
             mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
@@ -647,6 +664,13 @@ def user_register(user_id, user):
                 }
             )
 
+    unique = {}
+    for item in tenant_llm:
+        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
+        if key not in unique:
+            unique[key] = item
+    tenant_llm = list(unique.values())
+
     if not UserService.save(**user):
         return
     TenantService.insert(**tenant)
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 390bce48e..83456035d 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -63,12 +63,44 @@ def init_superuser():
         "invited_by": user_info["id"],
         "role": UserTenantRole.OWNER
     }
+
+    user_id = user_info["id"]
     tenant_llm = []
-    for llm in LLMService.query(fid=settings.LLM_FACTORY):
-        tenant_llm.append(
-            {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
-             "model_type": llm.model_type,
-             "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
+
+    seen = set()
+    factory_configs = []
+    for factory_config in [
+        settings.CHAT_CFG,
+        settings.EMBEDDING_CFG,
+        settings.ASR_CFG,
+        settings.IMAGE2TEXT_CFG,
+        settings.RERANK_CFG,
+    ]:
+        factory_name = factory_config["factory"]
+        if factory_name not in seen:
+            seen.add(factory_name)
+            factory_configs.append(factory_config)
+
+    for factory_config in factory_configs:
+        for llm in LLMService.query(fid=factory_config["factory"]):
+            tenant_llm.append(
+                {
+                    "tenant_id": user_id,
+                    "llm_factory": factory_config["factory"],
+                    "llm_name": llm.llm_name,
+                    "model_type": llm.model_type,
+                    "api_key": factory_config["api_key"],
+                    "api_base": factory_config["base_url"],
+                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
+                }
+            )
+
+    unique = {}
+    for item in tenant_llm:
+        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
+        if key not in unique:
+            unique[key] = item
+    tenant_llm = list(unique.values())
 
     if not UserService.save(**user_info):
         logging.error("can't init admin.")
@@ -103,7 +135,7 @@ def init_llm_factory():
     except Exception:
         pass
 
-    factory_llm_infos = settings.FACTORY_LLM_INFOS    
+    factory_llm_infos = settings.FACTORY_LLM_INFOS
     for factory_llm_info in factory_llm_infos:
         info = deepcopy(factory_llm_info)
         llm_infos = info.pop("llm")
diff --git a/api/settings.py b/api/settings.py
index f5577361b..f45483651 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -38,6 +38,11 @@ EMBEDDING_MDL = ""
 RERANK_MDL = ""
 ASR_MDL = ""
 IMAGE2TEXT_MDL = ""
+CHAT_CFG = ""
+EMBEDDING_CFG = ""
+RERANK_CFG = ""
+ASR_CFG = ""
+IMAGE2TEXT_CFG = ""
 API_KEY = None
 PARSERS = None
 HOST_IP = None
@@ -74,23 +79,22 @@ STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
 
 BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
 
+
 def get_or_create_secret_key():
     secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
     if secret_key and len(secret_key) >= 32:
         return secret_key
-    
+
     # Check if there's a configured secret key
     configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
     if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
         return configured_key
-    
+
     # Generate a new secure key and warn about it
     import logging
+
     new_key = secrets.token_hex(32)
-    logging.warning(
-        "SECURITY WARNING: Using auto-generated SECRET_KEY. "
-        f"Generated key: {new_key}"
-    )
+    logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
     return new_key
 
 
@@ -99,10 +103,10 @@ def init_settings():
     LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
     DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
     DATABASE = decrypt_database_config(name=DATABASE_TYPE)
-    LLM = get_base_config("user_default_llm", {})
-    LLM_DEFAULT_MODELS = LLM.get("default_models", {})
-    LLM_FACTORY = LLM.get("factory")
-    LLM_BASE_URL = LLM.get("base_url")
+    LLM = get_base_config("user_default_llm", {}) or {}
+    LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
+    LLM_FACTORY = LLM.get("factory", "") or ""
+    LLM_BASE_URL = LLM.get("base_url", "") or ""
     try:
         REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
     except Exception:
@@ -115,29 +119,34 @@ def init_settings():
         FACTORY_LLM_INFOS = []
 
     global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+    global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
     if not LIGHTEN:
         EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
 
-    if LLM_DEFAULT_MODELS:
-        CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
-        EMBEDDING_MDL = LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL)
-        RERANK_MDL = LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL)
-        ASR_MDL = LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL)
-        IMAGE2TEXT_MDL = LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)
-
-    # factory can be specified in the config name with "@". LLM_FACTORY will be used if not specified
-    CHAT_MDL = CHAT_MDL + (f"@{LLM_FACTORY}" if "@" not in CHAT_MDL and CHAT_MDL != "" else "")
-    EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
-    RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
-    ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
-    IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
-
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
     API_KEY = LLM.get("api_key")
     PARSERS = LLM.get(
         "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
     )
 
+    chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
+    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+    rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
+    asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
+    image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
+
+    CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+
+    CHAT_MDL = CHAT_CFG.get("model", "") or ""
+    EMBEDDING_MDL = EMBEDDING_CFG.get("model", "") or ""
+    RERANK_MDL = RERANK_CFG.get("model", "") or ""
+    ASR_MDL = ASR_CFG.get("model", "") or ""
+    IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
+
     HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
     HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
 
@@ -170,6 +179,7 @@ def init_settings():
 
     retrievaler = search.Dealer(docStoreConn)
     from graphrag import search as kg_search
+
     kg_retrievaler = kg_search.KGSearch(docStoreConn)
 
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
@@ -210,3 +220,34 @@ class RetCode(IntEnum, CustomEnum):
     SERVER_ERROR = 500
     FORBIDDEN = 403
     NOT_FOUND = 404
+
+
+def _parse_model_entry(entry):
+    if isinstance(entry, str):
+        return {"name": entry, "factory": None, "api_key": None, "base_url": None}
+    if isinstance(entry, dict):
+        name = entry.get("name") or entry.get("model") or ""
+        return {
+            "name": name,
+            "factory": entry.get("factory"),
+            "api_key": entry.get("api_key"),
+            "base_url": entry.get("base_url"),
+        }
+    return {"name": "", "factory": None, "api_key": None, "base_url": None}
+
+
+def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
+    name = (entry_dict.get("name") or "").strip()
+    m_factory = entry_dict.get("factory") or backup_factory or ""
+    m_api_key = entry_dict.get("api_key") or backup_api_key or ""
+    m_base_url = entry_dict.get("base_url") or backup_base_url or ""
+
+    if name and "@" not in name and m_factory:
+        name = f"{name}@{m_factory}"
+
+    return {
+        "model": name,
+        "factory": m_factory,
+        "api_key": m_api_key,
+        "base_url": m_base_url,
+    }
diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml
index 7a94a2a1d..9a995c2b2 100644
--- a/conf/service_conf.yaml
+++ b/conf/service_conf.yaml
@@ -64,9 +64,21 @@ redis:
 #   config:
 #     oss_table: 'opendal_storage'
 # user_default_llm:
-#   factory: 'Tongyi-Qianwen'
-#   api_key: 'sk-xxxxxxxxxxxxx'
-#   base_url: ''
+#   factory: 'BAAI'
+#   api_key: 'backup'
+#   base_url: 'backup_base_url'
+#   default_models:
+#     chat_model:
+#       name: 'qwen2.5-7b-instruct'
+#       factory: 'xxxx'
+#       api_key: 'xxxx'
+#       base_url: 'https://api.xx.com'
+#     embedding_model:
+#       name: 'bge-m3'
+#     rerank_model: 'bge-reranker-v2'
+#     asr_model:
+#       model: 'whisper-large-v3' # alias of name
+#     image2text_model: ''
 # oauth:
 #   oauth2:
 #     display_name: "OAuth2"