From 3654ae61c13ea9ef039c5cf7c797f6a21c7ef91d Mon Sep 17 00:00:00 2001 From: Wanderson Pinto dos Santos <85753826+wanpdsantos@users.noreply.github.com> Date: Tue, 4 Nov 2025 23:47:50 -0300 Subject: [PATCH] feat: add allowed factories variable to allow admins to restrict llms users can add (#11003) ### What problem does this PR solve? Currently, if we want to restrict the allowed factories users can use we need to delete from the database table manually. The proposal of this PR is to include a variable to that, if set, will restrict the LLM factories the users can see and add. This allow us to not touch the llm_factories.json or the database if the LLM factory is already inserted. Obs.: All the lint changes were from the pre-commit hook which I did not change. ### Type of change - [X] New Feature (non-breaking change which adds functionality) --- api/apps/llm_app.py | 156 +++++++++++++++-------------------------- api/settings.py | 6 +- api/utils/api_utils.py | 93 ++++++++++++------------ docs/configurations.md | 82 +++++++++++----------- 4 files changed, 148 insertions(+), 189 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 4f0e3c91c..c49097fac 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -23,16 +23,16 @@ from api.db.services.llm_service import LLMService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from common.constants import StatusEnum, LLMType from api.db.db_models import TenantLLM -from api.utils.api_utils import get_json_result +from api.utils.api_utils import get_json_result, get_allowed_llm_factories from common.base64_image import test_image from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel -@manager.route('/factories', methods=['GET']) # noqa: F821 +@manager.route("/factories", methods=["GET"]) # noqa: F821 @login_required def factories(): try: - fac = LLMFactoriesService.get_all() + fac = get_allowed_llm_factories() fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]] llms = LLMService.get_all() mdl_types = {} @@ -43,14 +43,13 @@ def factories(): mdl_types[m.fid] = set([]) mdl_types[m.fid].add(m.model_type) for f in fac: - f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, - LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) + f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) return get_json_result(data=fac) except Exception as e: return server_error_response(e) -@manager.route('/set_api_key', methods=['POST']) # noqa: F821 +@manager.route("/set_api_key", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "api_key") def set_api_key(): @@ -63,8 +62,7 @@ def set_api_key(): for llm in LLMService.query(fid=factory): if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0: @@ -74,52 +72,40 @@ def set_api_key(): msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) elif not chat_passed and llm.model_type == LLMType.CHAT.value: assert factory in ChatModel, f"Chat model from {factory} is not supported yet." - mdl = ChatModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) + mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], - {"temperature": 0.9, 'max_tokens': 50}) + m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True except Exception as e: - msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( - e) + msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) elif not rerank_passed and llm.model_type == LLMType.RERANK: assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." - mdl = RerankModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) if len(arr) == 0 or tc == 0: raise Exception("Fail") rerank_passed = True - logging.debug(f'passed model rerank {llm.llm_name}') + logging.debug(f"passed model rerank {llm.llm_name}") except Exception as e: - msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( - e) + msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) if any([embd_passed, chat_passed, rerank_passed]): - msg = '' + msg = "" break if msg: return get_data_error_result(message=msg) - llm_config = { - "api_key": req["api_key"], - "api_base": req.get("base_url", "") - } + llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")} for n in ["model_type", "llm_name"]: if n in req: llm_config[n] = req[n] for llm in LLMService.query(fid=factory): - llm_config["max_tokens"]=llm.max_tokens - if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, - TenantLLM.llm_factory == factory, - TenantLLM.llm_name == llm.llm_name], - llm_config): + llm_config["max_tokens"] = llm.max_tokens + if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config): TenantLLMService.save( tenant_id=current_user.id, llm_factory=factory, @@ -127,13 +113,13 @@ def set_api_key(): model_type=llm.model_type, api_key=llm_config["api_key"], api_base=llm_config["api_base"], - max_tokens=llm_config["max_tokens"] + max_tokens=llm_config["max_tokens"], ) return get_json_result(data=True) -@manager.route('/add_llm', methods=['POST']) # noqa: F821 +@manager.route("/add_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") def add_llm(): @@ -142,6 +128,9 @@ def add_llm(): api_key = req.get("api_key", "x") llm_name = req.get("llm_name") + if factory not in get_allowed_llm_factories(): + return get_data_error_result(message=f"LLM factory {factory} is not allowed") + def apikey_json(keys): nonlocal req return json.dumps({k: req.get(k, "") for k in keys}) @@ -204,7 +193,7 @@ def add_llm(): "llm_name": llm_name, "api_base": req.get("api_base", ""), "api_key": api_key, - "max_tokens": req.get("max_tokens") + "max_tokens": req.get("max_tokens"), } msg = "" @@ -212,10 +201,7 @@ def add_llm(): extra = {"provider": factory} if llm["model_type"] == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory]( - key=llm['api_key'], - model_name=mdl_nm, - base_url=llm["api_base"]) + mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0: @@ -225,42 +211,31 @@ def add_llm(): elif llm["model_type"] == LLMType.CHAT.value: assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory]( - key=llm['api_key'], + key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"], **extra, ) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { - "temperature": 0.9}) + m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str( - e) + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.RERANK: assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." try: - mdl = RerankModel[factory]( - key=llm["api_key"], - model_name=mdl_nm, - base_url=llm["api_base"] - ) + mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) if len(arr) == 0: raise Exception("Not known.") except KeyError: msg += f"{factory} dose not support this model({factory}/{mdl_nm})" except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str( - e) + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.IMAGE2TEXT.value: assert factory in CvModel, f"Image to text model from {factory} is not supported yet." - mdl = CvModel[factory]( - key=llm["api_key"], - model_name=mdl_nm, - base_url=llm["api_base"] - ) + mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: image_data = test_image m, tc = mdl.describe(image_data) @@ -270,9 +245,7 @@ def add_llm(): msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.TTS: assert factory in TTSModel, f"TTS model from {factory} is not supported yet." - mdl = TTSModel[factory]( - key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"] - ) + mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: for resp in mdl.tts("Hello~ RAGFlower!"): pass @@ -285,51 +258,46 @@ def add_llm(): if msg: return get_data_error_result(message=msg) - if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, - TenantLLM.llm_name == llm["llm_name"]], llm): + if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm): TenantLLMService.save(**llm) return get_json_result(data=True) -@manager.route('/delete_llm', methods=['POST']) # noqa: F821 +@manager.route("/delete_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") def delete_llm(): req = request.json - TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], - TenantLLM.llm_name == req["llm_name"]]) + TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) -@manager.route('/enable_llm', methods=['POST']) # noqa: F821 +@manager.route("/enable_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") def enable_llm(): req = request.json TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], - TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}) + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} + ) return get_json_result(data=True) -@manager.route('/delete_factory', methods=['POST']) # noqa: F821 +@manager.route("/delete_factory", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") def delete_factory(): req = request.json - TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) + TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) -@manager.route('/my_llms', methods=['GET']) # noqa: F821 +@manager.route("/my_llms", methods=["GET"]) # noqa: F821 @login_required def my_llms(): try: - include_details = request.args.get('include_details', 'false').lower() == 'true' + include_details = request.args.get("include_details", "false").lower() == "true" if include_details: res = {} @@ -345,40 +313,31 @@ def my_llms(): break if o_dict["llm_factory"] not in res: - res[o_dict["llm_factory"]] = { - "tags": factory_tags, - "llm": [] - } + res[o_dict["llm_factory"]] = {"tags": factory_tags, "llm": []} - res[o_dict["llm_factory"]]["llm"].append({ - "type": o_dict["model_type"], - "name": o_dict["llm_name"], - "used_token": o_dict["used_tokens"], - "api_base": o_dict["api_base"] or "", - "max_tokens": o_dict["max_tokens"] or 8192, - "status": o_dict["status"] or "1" - }) + res[o_dict["llm_factory"]]["llm"].append( + { + "type": o_dict["model_type"], + "name": o_dict["llm_name"], + "used_token": o_dict["used_tokens"], + "api_base": o_dict["api_base"] or "", + "max_tokens": o_dict["max_tokens"] or 8192, + "status": o_dict["status"] or "1", + } + ) else: res = {} for o in TenantLLMService.get_my_llms(current_user.id): if o["llm_factory"] not in res: - res[o["llm_factory"]] = { - "tags": o["tags"], - "llm": [] - } - res[o["llm_factory"]]["llm"].append({ - "type": o["model_type"], - "name": o["llm_name"], - "used_token": o["used_tokens"], - "status": o["status"] - }) + res[o["llm_factory"]] = {"tags": o["tags"], "llm": []} + res[o["llm_factory"]]["llm"].append({"type": o["model_type"], "name": o["llm_name"], "used_token": o["used_tokens"], "status": o["status"]}) return get_json_result(data=res) except Exception as e: return server_error_response(e) -@manager.route('/list', methods=['GET']) # noqa: F821 +@manager.route("/list", methods=["GET"]) # noqa: F821 @login_required def list_app(): self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"] @@ -386,14 +345,13 @@ def list_app(): model_type = request.args.get("model_type") try: objs = TenantLLMService.query(tenant_id=current_user.id) - facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value]) + facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value]) status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value} llms = LLMService.get_all() - llms = [m.to_dict() - for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status] + llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status] for m in llms: m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed - if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''): + if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""): m["available"] = True llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms]) diff --git a/api/settings.py b/api/settings.py index d44539b7c..02136fbef 100644 --- a/api/settings.py +++ b/api/settings.py @@ -46,6 +46,7 @@ HOST_IP = None HOST_PORT = None SECRET_KEY = None FACTORY_LLM_INFOS = None +ALLOWED_LLM_FACTORIES = None DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -77,7 +78,7 @@ STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8")) SMTP_CONF = None MAIL_SERVER = "" MAIL_PORT = 000 -MAIL_USE_SSL= True +MAIL_USE_SSL = True MAIL_USE_TLS = False MAIL_USERNAME = "" MAIL_PASSWORD = "" @@ -104,13 +105,14 @@ def get_or_create_secret_key(): def init_settings(): - global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED + global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") DATABASE = decrypt_database_config(name=DATABASE_TYPE) LLM = get_base_config("user_default_llm", {}) or {} LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {} LLM_FACTORY = LLM.get("factory", "") or "" LLM_BASE_URL = LLM.get("base_url", "") or "" + ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None) try: REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) except Exception: diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 88cf7f5b9..630073788 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -39,6 +39,7 @@ from common.constants import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from api.db.services.tenant_llm_service import LLMFactoriesService from common.connection_utils import timeout from common.constants import RetCode @@ -51,16 +52,15 @@ def serialize_for_json(obj): Recursively serialize objects to make them JSON serializable. Handles ModelMetaclass and other non-serializable objects. """ - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): # For objects with __dict__, try to serialize their attributes try: - return {key: serialize_for_json(value) for key, value in obj.__dict__.items() - if not key.startswith('_')} + return {key: serialize_for_json(value) for key, value in obj.__dict__.items() if not key.startswith("_")} except (AttributeError, TypeError): return str(obj) - elif hasattr(obj, '__name__'): + elif hasattr(obj, "__name__"): # For classes and metaclasses, return their name - return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, '__module__') else f"<{obj.__name__}>" + return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, "__module__") else f"<{obj.__name__}>" elif isinstance(obj, (list, tuple)): return [serialize_for_json(item) for item in obj] elif isinstance(obj, dict): @@ -71,6 +71,7 @@ def serialize_for_json(obj): # Fallback: convert to string representation return str(obj) + def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"): logging.exception(Exception(message)) result_dict = {"code": code, "message": message} @@ -99,8 +100,7 @@ def server_error_response(e): except Exception: return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None) if repr(e).find("index_not_found_exception") >= 0: - return get_json_result(code=RetCode.EXCEPTION_ERROR, - message="No chunk found, please upload file and parse it.") + return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) @@ -129,8 +129,7 @@ def validate_request(*args, **kwargs): if no_arguments: error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) if error_arguments: - error_string += "required argument values: {}".format( - ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string) return func(*_args, **_kwargs) @@ -145,8 +144,7 @@ def not_allowed_parameters(*params): input_arguments = flask_request.json or flask_request.form.to_dict() for param in params: if param in input_arguments: - return get_json_result(code=RetCode.ARGUMENT_ERROR, - message=f"Parameter {param} isn't allowed") + return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") return f(*args, **kwargs) return wrapper @@ -158,6 +156,7 @@ def active_required(f): @wraps(f) def wrapper(*args, **kwargs): from api.db.services import UserService + user_id = current_user.id usr = UserService.filter_by_id(user_id) # check is_active @@ -199,6 +198,7 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da else: return jsonify({"code": code, "message": message, "data": data}) + def token_required(func): @wraps(func) def decorated_function(*args, **kwargs): @@ -213,8 +213,7 @@ def token_required(func): token = authorization_list[1] objs = APIToken.query(token=token) if not objs: - return get_json_result(data=False, message="Authentication error: API key is invalid!", - code=RetCode.AUTHENTICATION_ERROR) + return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) kwargs["tenant_id"] = objs[0].tenant_id return func(*args, **kwargs) @@ -243,9 +242,10 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None): return jsonify(response) + def get_error_data_result( - message="Sorry! Data missing!", - code=RetCode.DATA_ERROR, + message="Sorry! Data missing!", + code=RetCode.DATA_ERROR, ): result_dict = {"code": code, "message": message} response = {} @@ -271,6 +271,7 @@ def get_error_operating_result(message="Operating error"): def generate_confirmation_token(): import secrets + return "ragflow-" + secrets.token_urlsafe(32) @@ -345,18 +346,7 @@ def get_parser_config(chunk_method, parser_config): return merged_config -def get_data_openai( - id=None, - created=None, - model=None, - prompt_tokens=0, - completion_tokens=0, - content=None, - finish_reason=None, - object="chat.completion", - param=None, - stream=False -): +def get_data_openai(id=None, created=None, model=None, prompt_tokens=0, completion_tokens=0, content=None, finish_reason=None, object="chat.completion", param=None, stream=False): total_tokens = prompt_tokens + completion_tokens if stream: @@ -364,11 +354,13 @@ def get_data_openai( "id": f"{id}", "object": "chat.completion.chunk", "model": model, - "choices": [{ - "delta": {"content": content}, - "finish_reason": finish_reason, - "index": 0, - }], + "choices": [ + { + "delta": {"content": content}, + "finish_reason": finish_reason, + "index": 0, + } + ], } return { @@ -387,15 +379,14 @@ def get_data_openai( "rejected_prediction_tokens": 0, }, }, - "choices": [{ - "message": { - "role": "assistant", - "content": content - }, - "logprobs": None, - "finish_reason": finish_reason, - "index": 0, - }], + "choices": [ + { + "message": {"role": "assistant", "content": content}, + "logprobs": None, + "finish_reason": finish_reason, + "index": 0, + } + ], } @@ -431,6 +422,7 @@ def check_duplicate_ids(ids, id_type="item"): def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService + """ Verifies availability of an embedding model for a specific tenant. @@ -469,11 +461,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) - is_tenant_model = any( - llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for - llm in tenant_llms) + is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) - is_builtin_model = llm_factory=='Builtin' + is_builtin_model = llm_factory == "Builtin" if not (is_builtin_model or is_tenant_model or in_llm_service): return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") @@ -610,7 +600,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s return {}, str(e) - async def is_strong_enough(chat_model, embedding_model): count = settings.STRONG_TEST_COUNT if not chat_model or not embedding_model: @@ -626,9 +615,7 @@ async def is_strong_enough(chat_model, embedding_model): _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) if chat_model: with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", - "content": "Are you strong enough!?"}], - {})) + res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) if res.find("**ERROR**") >= 0: raise Exception(res) @@ -636,3 +623,11 @@ async def is_strong_enough(chat_model, embedding_model): async with trio.open_nursery() as nursery: for _ in range(count): nursery.start_soon(_is_strong_enough) + + +def get_allowed_llm_factories() -> list: + factories = LLMFactoriesService.get_all() + if settings.ALLOWED_LLM_FACTORIES is None: + return factories + + return [factory for factory in factories if factory.name in settings.ALLOWED_LLM_FACTORIES] diff --git a/docs/configurations.md b/docs/configurations.md index 76fe1a81f..7aba314b4 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -29,9 +29,9 @@ docker compose -f docker/docker-compose.yml up -d ## Docker Compose -- **docker-compose.yml** +- **docker-compose.yml** Sets up environment for RAGFlow and its dependencies. -- **docker-compose-base.yml** +- **docker-compose-base.yml** Sets up environment for RAGFlow's dependencies: Elasticsearch/[Infinity](https://github.com/infiniflow/infinity), MySQL, MinIO, and Redis. :::danger IMPORTANT @@ -44,97 +44,97 @@ The [.env](https://github.com/infiniflow/ragflow/blob/main/docker/.env) file con ### Elasticsearch -- `STACK_VERSION` +- `STACK_VERSION` The version of Elasticsearch. Defaults to `8.11.3` -- `ES_PORT` +- `ES_PORT` The port used to expose the Elasticsearch service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `1200`. -- `ELASTIC_PASSWORD` +- `ELASTIC_PASSWORD` The password for Elasticsearch. ### Kibana -- `KIBANA_PORT` +- `KIBANA_PORT` The port used to expose the Kibana service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `6601`. -- `KIBANA_USER` +- `KIBANA_USER` The username for Kibana. Defaults to `rag_flow`. -- `KIBANA_PASSWORD` +- `KIBANA_PASSWORD` The password for Kibana. Defaults to `infini_rag_flow`. ### Resource management -- `MEM_LIMIT` +- `MEM_LIMIT` The maximum amount of the memory, in bytes, that *a specific* Docker container can use while running. Defaults to `8073741824`. ### MySQL -- `MYSQL_PASSWORD` +- `MYSQL_PASSWORD` The password for MySQL. -- `MYSQL_PORT` +- `MYSQL_PORT` The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`. ### MinIO RAGFlow utilizes MinIO as its object storage solution, leveraging its scalability to store and manage all uploaded files. -- `MINIO_CONSOLE_PORT` +- `MINIO_CONSOLE_PORT` The port used to expose the MinIO console interface to the host machine, allowing **external** access to the web-based console running inside the Docker container. Defaults to `9001` -- `MINIO_PORT` +- `MINIO_PORT` The port used to expose the MinIO API service to the host machine, allowing **external** access to the MinIO object storage service running inside the Docker container. Defaults to `9000`. -- `MINIO_USER` +- `MINIO_USER` The username for MinIO. -- `MINIO_PASSWORD` +- `MINIO_PASSWORD` The password for MinIO. ### Redis -- `REDIS_PORT` +- `REDIS_PORT` The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`. -- `REDIS_PASSWORD` +- `REDIS_PASSWORD` The password for Redis. ### RAGFlow -- `SVR_HTTP_PORT` +- `SVR_HTTP_PORT` The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`. -- `RAGFLOW-IMAGE` - The Docker image edition. Available editions: - - - `infiniflow/ragflow:v0.21.1-slim` (default): The RAGFlow Docker image without embedding models. +- `RAGFLOW-IMAGE` + The Docker image edition. Available editions: + + - `infiniflow/ragflow:v0.21.1-slim` (default): The RAGFlow Docker image without embedding models. - `infiniflow/ragflow:v0.21.1`: The RAGFlow Docker image with embedding models including: - Built-in embedding models: - - `BAAI/bge-large-zh-v1.5` + - `BAAI/bge-large-zh-v1.5` - `maidalun1020/bce-embedding-base_v1` -:::tip NOTE -If you cannot download the RAGFlow Docker image, try the following mirrors. +:::tip NOTE +If you cannot download the RAGFlow Docker image, try the following mirrors. -- For the `nightly` edition: +- For the `nightly` edition: - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. ::: ### Embedding service -- `TEI_MODEL` +- `TEI_MODEL` The embedding model which text-embeddings-inference serves. Allowed values are one of `Qwen/Qwen3-Embedding-0.6B`(default), `BAAI/bge-m3`, and `BAAI/bge-small-en-v1.5`. -- `TEI_PORT` +- `TEI_PORT` The port used to expose the text-embeddings-inference service to the host machine, allowing **external** access to the text-embeddings-inference service running inside the Docker container. Defaults to `6380`. ### Timezone -- `TZ` +- `TZ` The local time zone. Defaults to `Asia/Shanghai`. ### Hugging Face mirror site -- `HF_ENDPOINT` +- `HF_ENDPOINT` The mirror site for huggingface.co. It is disabled by default. You can uncomment this line if you have limited access to the primary Hugging Face domain. ### MacOS -- `MACOS` +- `MACOS` Optimizations for macOS. It is disabled by default. You can uncomment this line if your OS is macOS. ### User registration @@ -153,7 +153,7 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - `port`: The API server's serving port inside the Docker container. Defaults to `9380`. ### `mysql` - + - `name`: The MySQL database name. Defaults to `rag_flow`. - `user`: The username for MySQL. - `password`: The password for MySQL. @@ -162,12 +162,12 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - `stale_timeout`: Timeout in seconds. ### `minio` - + - `user`: The username for MinIO. - `password`: The password for MinIO. - `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`. -### `oauth` +### `oauth` The OAuth configuration for signing up or signing in to RAGFlow using a third-party account. @@ -184,7 +184,7 @@ The OAuth configuration for signing up or signing in to RAGFlow using a third-pa - `scope`: Requested permission scope, a space-separated string. For example, `openid profile email`. - `redirect_uri`: Required, URI to which the authorization server redirects during the authentication flow to return results. Must match the callback URI registered with the authentication server. Format: `https://your-app.com/v1/user/oauth/callback/`. For local configuration, you can directly use `http://127.0.0.1:80/v1/user/oauth/callback/`. -:::tip NOTE +:::tip NOTE The following are best practices for configuring various third-party authentication methods. You can configure one or multiple third-party authentication methods for Ragflow: ```yaml oauth: @@ -216,9 +216,9 @@ oauth: ``` ::: -### `user_default_llm` +### `user_default_llm` -The default LLM to use for a new RAGFlow user. It is disabled by default. To enable this feature, uncomment the corresponding lines in **service_conf.yaml.template**. +The default LLM to use for a new RAGFlow user. It is disabled by default. To enable this feature, uncomment the corresponding lines in **service_conf.yaml.template**. - `factory`: The LLM supplier. Available options: - `"OpenAI"` @@ -228,7 +228,11 @@ The default LLM to use for a new RAGFlow user. It is disabled by default. To ena - `"VolcEngine"` - `"ZHIPU-AI"` - `api_key`: The API key for the specified LLM. You will need to apply for your model API key online. +- `allowed_factories`: If this is set, the users will be allowed to add only the factories in this list. + - `"OpenAI"` + - `"DeepSeek"` + - `"Moonshot"` -:::tip NOTE +:::tip NOTE If you do not set the default LLM here, configure the default LLM on the **Settings** page in the RAGFlow UI. -::: \ No newline at end of file +:::