feat: add allowed factories variable to allow admins to restrict llms users can add (#11003)

### What problem does this PR solve?

Currently, restricting the LLM factories that users can choose requires deleting rows from the database table manually. This PR introduces a configuration variable that, if set, restricts the LLM factories users can see and add. With it, there is no need to touch llm_factories.json or the database when the LLM factory is already inserted.

Note: all lint changes come from the pre-commit hook, which I did not modify.
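
For illustration, a minimal sketch of how the new variable might be set in the `user_default_llm` section of the service config (the file name and placeholder values are just examples; leaving `allowed_factories` unset keeps every factory available):

```yaml
user_default_llm:
  factory: "OpenAI"        # existing default-LLM settings, unchanged
  api_key: "sk-..."        # placeholder
  allowed_factories:       # new: if present, users can only see/add these factories
    - "OpenAI"
    - "DeepSeek"
    - "Moonshot"
```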

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
Authored by Wanderson Pinto dos Santos on 2025-11-04 23:47:50 -03:00, committed by GitHub
parent bab3fce136 · commit 3654ae61c1
4 changed files with 148 additions and 189 deletions

View File

@@ -23,16 +23,16 @@ from api.db.services.llm_service import LLMService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result
+from api.utils.api_utils import get_json_result, get_allowed_llm_factories
 from common.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel


-@manager.route('/factories', methods=['GET'])  # noqa: F821
+@manager.route("/factories", methods=["GET"])  # noqa: F821
 @login_required
 def factories():
     try:
-        fac = LLMFactoriesService.get_all()
+        fac = get_allowed_llm_factories()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
         llms = LLMService.get_all()
         mdl_types = {}
@@ -43,14 +43,13 @@ def factories():
                 mdl_types[m.fid] = set([])
             mdl_types[m.fid].add(m.model_type)
         for f in fac:
-            f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK,
-                                                              LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
+            f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
         return get_json_result(data=fac)
     except Exception as e:
         return server_error_response(e)


-@manager.route('/set_api_key', methods=['POST'])  # noqa: F821
+@manager.route("/set_api_key", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("llm_factory", "api_key")
 def set_api_key():
@@ -63,8 +62,7 @@ def set_api_key():
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-            mdl = EmbeddingModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+            mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@ -74,52 +72,40 @@ def set_api_key():
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
elif not chat_passed and llm.model_type == LLMType.CHAT.value: elif not chat_passed and llm.model_type == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet." assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory]( mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try: try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
{"temperature": 0.9, 'max_tokens': 50})
if m.find("**ERROR**") >= 0: if m.find("**ERROR**") >= 0:
raise Exception(m) raise Exception(m)
chat_passed = True chat_passed = True
except Exception as e: except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
e)
elif not rerank_passed and llm.model_type == LLMType.RERANK: elif not rerank_passed and llm.model_type == LLMType.RERANK:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory]( mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try: try:
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0: if len(arr) == 0 or tc == 0:
raise Exception("Fail") raise Exception("Fail")
rerank_passed = True rerank_passed = True
logging.debug(f'passed model rerank {llm.llm_name}') logging.debug(f"passed model rerank {llm.llm_name}")
except Exception as e: except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
e)
if any([embd_passed, chat_passed, rerank_passed]): if any([embd_passed, chat_passed, rerank_passed]):
msg = '' msg = ""
break break
if msg: if msg:
return get_data_error_result(message=msg) return get_data_error_result(message=msg)
llm_config = { llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")}
"api_key": req["api_key"],
"api_base": req.get("base_url", "")
}
for n in ["model_type", "llm_name"]: for n in ["model_type", "llm_name"]:
if n in req: if n in req:
llm_config[n] = req[n] llm_config[n] = req[n]
for llm in LLMService.query(fid=factory): for llm in LLMService.query(fid=factory):
llm_config["max_tokens"] = llm.max_tokens llm_config["max_tokens"] = llm.max_tokens
if not TenantLLMService.filter_update( if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config):
[TenantLLM.tenant_id == current_user.id,
TenantLLM.llm_factory == factory,
TenantLLM.llm_name == llm.llm_name],
llm_config):
TenantLLMService.save( TenantLLMService.save(
tenant_id=current_user.id, tenant_id=current_user.id,
llm_factory=factory, llm_factory=factory,
@ -127,13 +113,13 @@ def set_api_key():
model_type=llm.model_type, model_type=llm.model_type,
api_key=llm_config["api_key"], api_key=llm_config["api_key"],
api_base=llm_config["api_base"], api_base=llm_config["api_base"],
max_tokens=llm_config["max_tokens"] max_tokens=llm_config["max_tokens"],
) )
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/add_llm', methods=['POST']) # noqa: F821 @manager.route("/add_llm", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
def add_llm(): def add_llm():
@@ -142,6 +128,9 @@ def add_llm():
     api_key = req.get("api_key", "x")
     llm_name = req.get("llm_name")

+    if factory not in get_allowed_llm_factories():
+        return get_data_error_result(message=f"LLM factory {factory} is not allowed")
+
     def apikey_json(keys):
         nonlocal req
         return json.dumps({k: req.get(k, "") for k in keys})
@@ -204,7 +193,7 @@ def add_llm():
         "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
         "api_key": api_key,
-        "max_tokens": req.get("max_tokens")
+        "max_tokens": req.get("max_tokens"),
     }

     msg = ""
@ -212,10 +201,7 @@ def add_llm():
extra = {"provider": factory} extra = {"provider": factory}
if llm["model_type"] == LLMType.EMBEDDING.value: if llm["model_type"] == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory]( mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
key=llm['api_key'],
model_name=mdl_nm,
base_url=llm["api_base"])
try: try:
arr, tc = mdl.encode(["Test if the api key is available"]) arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0: if len(arr[0]) == 0:
@ -225,42 +211,31 @@ def add_llm():
elif llm["model_type"] == LLMType.CHAT.value: elif llm["model_type"] == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet." assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory]( mdl = ChatModel[factory](
key=llm['api_key'], key=llm["api_key"],
model_name=mdl_nm, model_name=mdl_nm,
base_url=llm["api_base"], base_url=llm["api_base"],
**extra, **extra,
) )
try: try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0: if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m) raise Exception(m)
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str( msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
e)
elif llm["model_type"] == LLMType.RERANK: elif llm["model_type"] == LLMType.RERANK:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try: try:
mdl = RerankModel[factory]( mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
key=llm["api_key"],
model_name=mdl_nm,
base_url=llm["api_base"]
)
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0: if len(arr) == 0:
raise Exception("Not known.") raise Exception("Not known.")
except KeyError: except KeyError:
msg += f"{factory} dose not support this model({factory}/{mdl_nm})" msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str( msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value: elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet." assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory]( mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
key=llm["api_key"],
model_name=mdl_nm,
base_url=llm["api_base"]
)
try: try:
image_data = test_image image_data = test_image
m, tc = mdl.describe(image_data) m, tc = mdl.describe(image_data)
@ -270,9 +245,7 @@ def add_llm():
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.TTS: elif llm["model_type"] == LLMType.TTS:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet." assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory]( mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]
)
try: try:
for resp in mdl.tts("Hello~ RAGFlower!"): for resp in mdl.tts("Hello~ RAGFlower!"):
pass pass
@ -285,51 +258,46 @@ def add_llm():
if msg: if msg:
return get_data_error_result(message=msg) return get_data_error_result(message=msg)
if not TenantLLMService.filter_update( if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory,
TenantLLM.llm_name == llm["llm_name"]], llm):
TenantLLMService.save(**llm) TenantLLMService.save(**llm)
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/delete_llm', methods=['POST']) # noqa: F821 @manager.route("/delete_llm", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
def delete_llm(): def delete_llm():
req = request.json req = request.json
TenantLLMService.filter_delete( TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/enable_llm', methods=['POST']) # noqa: F821 @manager.route("/enable_llm", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
def enable_llm(): def enable_llm():
req = request.json req = request.json
TenantLLMService.filter_update( TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}) )
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/delete_factory', methods=['POST']) # noqa: F821 @manager.route("/delete_factory", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
def delete_factory(): def delete_factory():
req = request.json req = request.json
TenantLLMService.filter_delete( TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET']) # noqa: F821 @manager.route("/my_llms", methods=["GET"]) # noqa: F821
@login_required @login_required
def my_llms(): def my_llms():
try: try:
include_details = request.args.get('include_details', 'false').lower() == 'true' include_details = request.args.get("include_details", "false").lower() == "true"
if include_details: if include_details:
res = {} res = {}
@ -345,40 +313,31 @@ def my_llms():
break break
if o_dict["llm_factory"] not in res: if o_dict["llm_factory"] not in res:
res[o_dict["llm_factory"]] = { res[o_dict["llm_factory"]] = {"tags": factory_tags, "llm": []}
"tags": factory_tags,
"llm": []
}
res[o_dict["llm_factory"]]["llm"].append({ res[o_dict["llm_factory"]]["llm"].append(
{
"type": o_dict["model_type"], "type": o_dict["model_type"],
"name": o_dict["llm_name"], "name": o_dict["llm_name"],
"used_token": o_dict["used_tokens"], "used_token": o_dict["used_tokens"],
"api_base": o_dict["api_base"] or "", "api_base": o_dict["api_base"] or "",
"max_tokens": o_dict["max_tokens"] or 8192, "max_tokens": o_dict["max_tokens"] or 8192,
"status": o_dict["status"] or "1" "status": o_dict["status"] or "1",
}) }
)
else: else:
res = {} res = {}
for o in TenantLLMService.get_my_llms(current_user.id): for o in TenantLLMService.get_my_llms(current_user.id):
if o["llm_factory"] not in res: if o["llm_factory"] not in res:
res[o["llm_factory"]] = { res[o["llm_factory"]] = {"tags": o["tags"], "llm": []}
"tags": o["tags"], res[o["llm_factory"]]["llm"].append({"type": o["model_type"], "name": o["llm_name"], "used_token": o["used_tokens"], "status": o["status"]})
"llm": []
}
res[o["llm_factory"]]["llm"].append({
"type": o["model_type"],
"name": o["llm_name"],
"used_token": o["used_tokens"],
"status": o["status"]
})
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/list', methods=['GET']) # noqa: F821 @manager.route("/list", methods=["GET"]) # noqa: F821
@login_required @login_required
def list_app(): def list_app():
self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"] self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
@@ -389,11 +348,10 @@ def list_app():
     facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
     status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
     llms = LLMService.get_all()
-    llms = [m.to_dict()
-            for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
+    llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
     for m in llms:
         m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
-        if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''):
+        if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):
             m["available"] = True
     llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])

View File

@@ -46,6 +46,7 @@ HOST_IP = None
 HOST_PORT = None
 SECRET_KEY = None
 FACTORY_LLM_INFOS = None
+ALLOWED_LLM_FACTORIES = None

 DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
 DATABASE = decrypt_database_config(name=DATABASE_TYPE)
@@ -104,13 +105,14 @@ def get_or_create_secret_key():


 def init_settings():
-    global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED
+    global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES
     DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
     DATABASE = decrypt_database_config(name=DATABASE_TYPE)
     LLM = get_base_config("user_default_llm", {}) or {}
     LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
     LLM_FACTORY = LLM.get("factory", "") or ""
     LLM_BASE_URL = LLM.get("base_url", "") or ""
+    ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None)
     try:
         REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
     except Exception:

View File

@@ -39,6 +39,7 @@ from common.constants import ActiveEnum
 from api.db.db_models import APIToken
 from api.utils.json_encode import CustomJSONEncoder
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
+from api.db.services.tenant_llm_service import LLMFactoriesService
 from common.connection_utils import timeout
 from common.constants import RetCode

@ -51,16 +52,15 @@ def serialize_for_json(obj):
Recursively serialize objects to make them JSON serializable. Recursively serialize objects to make them JSON serializable.
Handles ModelMetaclass and other non-serializable objects. Handles ModelMetaclass and other non-serializable objects.
""" """
if hasattr(obj, '__dict__'): if hasattr(obj, "__dict__"):
# For objects with __dict__, try to serialize their attributes # For objects with __dict__, try to serialize their attributes
try: try:
return {key: serialize_for_json(value) for key, value in obj.__dict__.items() return {key: serialize_for_json(value) for key, value in obj.__dict__.items() if not key.startswith("_")}
if not key.startswith('_')}
except (AttributeError, TypeError): except (AttributeError, TypeError):
return str(obj) return str(obj)
elif hasattr(obj, '__name__'): elif hasattr(obj, "__name__"):
# For classes and metaclasses, return their name # For classes and metaclasses, return their name
return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, '__module__') else f"<{obj.__name__}>" return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, "__module__") else f"<{obj.__name__}>"
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return [serialize_for_json(item) for item in obj] return [serialize_for_json(item) for item in obj]
elif isinstance(obj, dict): elif isinstance(obj, dict):
@ -71,6 +71,7 @@ def serialize_for_json(obj):
# Fallback: convert to string representation # Fallback: convert to string representation
return str(obj) return str(obj)
def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"): def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"):
logging.exception(Exception(message)) logging.exception(Exception(message))
result_dict = {"code": code, "message": message} result_dict = {"code": code, "message": message}
@ -99,8 +100,7 @@ def server_error_response(e):
except Exception: except Exception:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None) return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
if repr(e).find("index_not_found_exception") >= 0: if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR, return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
message="No chunk found, please upload file and parse it.")
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
@ -129,8 +129,7 @@ def validate_request(*args, **kwargs):
if no_arguments: if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments: if error_arguments:
error_string += "required argument values: {}".format( error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string) return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs) return func(*_args, **_kwargs)
@ -145,8 +144,7 @@ def not_allowed_parameters(*params):
input_arguments = flask_request.json or flask_request.form.to_dict() input_arguments = flask_request.json or flask_request.form.to_dict()
for param in params: for param in params:
if param in input_arguments: if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
message=f"Parameter {param} isn't allowed")
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@ -158,6 +156,7 @@ def active_required(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
from api.db.services import UserService from api.db.services import UserService
user_id = current_user.id user_id = current_user.id
usr = UserService.filter_by_id(user_id) usr = UserService.filter_by_id(user_id)
# check is_active # check is_active
@ -199,6 +198,7 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
else: else:
return jsonify({"code": code, "message": message, "data": data}) return jsonify({"code": code, "message": message, "data": data})
def token_required(func): def token_required(func):
@wraps(func) @wraps(func)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
@ -213,8 +213,7 @@ def token_required(func):
token = authorization_list[1] token = authorization_list[1]
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result(data=False, message="Authentication error: API key is invalid!", return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
code=RetCode.AUTHENTICATION_ERROR)
kwargs["tenant_id"] = objs[0].tenant_id kwargs["tenant_id"] = objs[0].tenant_id
return func(*args, **kwargs) return func(*args, **kwargs)
@ -243,6 +242,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
return jsonify(response) return jsonify(response)
def get_error_data_result( def get_error_data_result(
message="Sorry! Data missing!", message="Sorry! Data missing!",
code=RetCode.DATA_ERROR, code=RetCode.DATA_ERROR,
@ -271,6 +271,7 @@ def get_error_operating_result(message="Operating error"):
def generate_confirmation_token(): def generate_confirmation_token():
import secrets import secrets
return "ragflow-" + secrets.token_urlsafe(32) return "ragflow-" + secrets.token_urlsafe(32)
@ -345,18 +346,7 @@ def get_parser_config(chunk_method, parser_config):
return merged_config return merged_config
def get_data_openai( def get_data_openai(id=None, created=None, model=None, prompt_tokens=0, completion_tokens=0, content=None, finish_reason=None, object="chat.completion", param=None, stream=False):
id=None,
created=None,
model=None,
prompt_tokens=0,
completion_tokens=0,
content=None,
finish_reason=None,
object="chat.completion",
param=None,
stream=False
):
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
if stream: if stream:
@ -364,11 +354,13 @@ def get_data_openai(
"id": f"{id}", "id": f"{id}",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"model": model, "model": model,
"choices": [{ "choices": [
{
"delta": {"content": content}, "delta": {"content": content},
"finish_reason": finish_reason, "finish_reason": finish_reason,
"index": 0, "index": 0,
}], }
],
} }
return { return {
@ -387,15 +379,14 @@ def get_data_openai(
"rejected_prediction_tokens": 0, "rejected_prediction_tokens": 0,
}, },
}, },
"choices": [{ "choices": [
"message": { {
"role": "assistant", "message": {"role": "assistant", "content": content},
"content": content
},
"logprobs": None, "logprobs": None,
"finish_reason": finish_reason, "finish_reason": finish_reason,
"index": 0, "index": 0,
}], }
],
} }
@ -431,6 +422,7 @@ def check_duplicate_ids(ids, id_type="item"):
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
""" """
Verifies availability of an embedding model for a specific tenant. Verifies availability of an embedding model for a specific tenant.
@ -469,11 +461,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))
tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
is_tenant_model = any( is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for
llm in tenant_llms)
is_builtin_model = llm_factory=='Builtin' is_builtin_model = llm_factory == "Builtin"
if not (is_builtin_model or is_tenant_model or in_llm_service): if not (is_builtin_model or is_tenant_model or in_llm_service):
return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
@ -610,7 +600,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s
return {}, str(e) return {}, str(e)
async def is_strong_enough(chat_model, embedding_model): async def is_strong_enough(chat_model, embedding_model):
count = settings.STRONG_TEST_COUNT count = settings.STRONG_TEST_COUNT
if not chat_model or not embedding_model: if not chat_model or not embedding_model:
@ -626,9 +615,7 @@ async def is_strong_enough(chat_model, embedding_model):
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
if chat_model: if chat_model:
with trio.fail_after(30): with trio.fail_after(30):
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
"content": "Are you strong enough!?"}],
{}))
if res.find("**ERROR**") >= 0: if res.find("**ERROR**") >= 0:
raise Exception(res) raise Exception(res)
@@ -636,3 +623,11 @@ async def is_strong_enough(chat_model, embedding_model):
     async with trio.open_nursery() as nursery:
         for _ in range(count):
             nursery.start_soon(_is_strong_enough)
+
+
+def get_allowed_llm_factories() -> list:
+    factories = LLMFactoriesService.get_all()
+    if settings.ALLOWED_LLM_FACTORIES is None:
+        return factories
+
+    return [factory for factory in factories if factory.name in settings.ALLOWED_LLM_FACTORIES]
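
For clarity, a small self-contained sketch of the filtering semantics implemented by the helper above (the `Factory` class and factory names here are hypothetical stand-ins, not part of this commit): `None` means no restriction, while a list keeps only the factories whose names appear in it.

```python
# Minimal illustration of the allowed-factories filter (hypothetical stand-ins).
class Factory:
    def __init__(self, name: str):
        self.name = name


def filter_factories(factories, allowed):
    # None -> no restriction; a list -> keep only factories whose name is listed
    # (an empty list therefore allows nothing)
    if allowed is None:
        return factories
    return [f for f in factories if f.name in allowed]


factories = [Factory("OpenAI"), Factory("DeepSeek"), Factory("Ollama")]
print([f.name for f in filter_factories(factories, None)])                    # ['OpenAI', 'DeepSeek', 'Ollama']
print([f.name for f in filter_factories(factories, ["OpenAI", "DeepSeek"])])  # ['OpenAI', 'DeepSeek']
```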

View File

@@ -228,6 +228,10 @@ The default LLM to use for a new RAGFlow user. It is disabled by default. To ena
   - `"VolcEngine"`
   - `"ZHIPU-AI"`
 - `api_key`: The API key for the specified LLM. You will need to apply for your model API key online.
+- `allowed_factories`: If this is set, the users will be allowed to add only the factories in this list.
+  - `"OpenAI"`
+  - `"DeepSeek"`
+  - `"Moonshot"`

 :::tip NOTE
 If you do not set the default LLM here, configure the default LLM on the **Settings** page in the RAGFlow UI.