diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 4f0e3c91c..c49097fac 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -23,16 +23,16 @@ from api.db.services.llm_service import LLMService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from common.constants import StatusEnum, LLMType from api.db.db_models import TenantLLM -from api.utils.api_utils import get_json_result +from api.utils.api_utils import get_json_result, get_allowed_llm_factories from common.base64_image import test_image from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel -@manager.route('/factories', methods=['GET']) # noqa: F821 +@manager.route("/factories", methods=["GET"]) # noqa: F821 @login_required def factories(): try: - fac = LLMFactoriesService.get_all() + fac = get_allowed_llm_factories() fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]] llms = LLMService.get_all() mdl_types = {} @@ -43,14 +43,13 @@ def factories(): mdl_types[m.fid] = set([]) mdl_types[m.fid].add(m.model_type) for f in fac: - f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, - LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) + f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) return get_json_result(data=fac) except Exception as e: return server_error_response(e) -@manager.route('/set_api_key', methods=['POST']) # noqa: F821 +@manager.route("/set_api_key", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "api_key") def set_api_key(): @@ -63,8 +62,7 @@ def set_api_key(): for llm in LLMService.query(fid=factory): if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0: @@ -74,52 +72,40 @@ def set_api_key(): msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) elif not chat_passed and llm.model_type == LLMType.CHAT.value: assert factory in ChatModel, f"Chat model from {factory} is not supported yet." - mdl = ChatModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) + mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], - {"temperature": 0.9, 'max_tokens': 50}) + m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True except Exception as e: - msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( - e) + msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) elif not rerank_passed and llm.model_type == LLMType.RERANK: assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." 
- mdl = RerankModel[factory]( - req["api_key"], llm.llm_name, base_url=req.get("base_url")) + mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) if len(arr) == 0 or tc == 0: raise Exception("Fail") rerank_passed = True - logging.debug(f'passed model rerank {llm.llm_name}') + logging.debug(f"passed model rerank {llm.llm_name}") except Exception as e: - msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str( - e) + msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) if any([embd_passed, chat_passed, rerank_passed]): - msg = '' + msg = "" break if msg: return get_data_error_result(message=msg) - llm_config = { - "api_key": req["api_key"], - "api_base": req.get("base_url", "") - } + llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")} for n in ["model_type", "llm_name"]: if n in req: llm_config[n] = req[n] for llm in LLMService.query(fid=factory): - llm_config["max_tokens"]=llm.max_tokens - if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, - TenantLLM.llm_factory == factory, - TenantLLM.llm_name == llm.llm_name], - llm_config): + llm_config["max_tokens"] = llm.max_tokens + if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config): TenantLLMService.save( tenant_id=current_user.id, llm_factory=factory, @@ -127,13 +113,13 @@ def set_api_key(): model_type=llm.model_type, api_key=llm_config["api_key"], api_base=llm_config["api_base"], - max_tokens=llm_config["max_tokens"] + max_tokens=llm_config["max_tokens"], ) return get_json_result(data=True) -@manager.route('/add_llm', methods=['POST']) # noqa: F821 +@manager.route("/add_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") def add_llm(): @@ -142,6 +128,9 @@ def add_llm(): api_key = req.get("api_key", "x") llm_name = req.get("llm_name") + # get_allowed_llm_factories() returns factory records, so compare against their names + if factory not in [f.name for f in get_allowed_llm_factories()]: + return get_data_error_result(message=f"LLM factory {factory} is not allowed") + def apikey_json(keys): nonlocal req return json.dumps({k: req.get(k, "") for k in keys}) @@ -204,7 +193,7 @@ def add_llm(): "llm_name": llm_name, "api_base": req.get("api_base", ""), "api_key": api_key, - "max_tokens": req.get("max_tokens") + "max_tokens": req.get("max_tokens"), } msg = "" @@ -212,10 +201,7 @@ def add_llm(): extra = {"provider": factory} if llm["model_type"] == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory]( - key=llm['api_key'], - model_name=mdl_nm, - base_url=llm["api_base"]) + mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0: @@ -225,42 +211,31 @@ def add_llm(): elif llm["model_type"] == LLMType.CHAT.value: assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory]( - key=llm['api_key'], + key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"], **extra, ) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { - "temperature": 0.9}) + m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! 
How are you doing!"}], {"temperature": 0.9}) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str( - e) + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.RERANK: assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." try: - mdl = RerankModel[factory]( - key=llm["api_key"], - model_name=mdl_nm, - base_url=llm["api_base"] - ) + mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) if len(arr) == 0: raise Exception("Not known.") except KeyError: msg += f"{factory} dose not support this model({factory}/{mdl_nm})" except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str( - e) + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.IMAGE2TEXT.value: assert factory in CvModel, f"Image to text model from {factory} is not supported yet." - mdl = CvModel[factory]( - key=llm["api_key"], - model_name=mdl_nm, - base_url=llm["api_base"] - ) + mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: image_data = test_image m, tc = mdl.describe(image_data) @@ -270,9 +245,7 @@ def add_llm(): msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) elif llm["model_type"] == LLMType.TTS: assert factory in TTSModel, f"TTS model from {factory} is not supported yet." - mdl = TTSModel[factory]( - key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"] - ) + mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) try: for resp in mdl.tts("Hello~ RAGFlower!"): pass @@ -285,51 +258,46 @@ def add_llm(): if msg: return get_data_error_result(message=msg) - if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, - TenantLLM.llm_name == llm["llm_name"]], llm): + if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm): TenantLLMService.save(**llm) return get_json_result(data=True) -@manager.route('/delete_llm', methods=['POST']) # noqa: F821 +@manager.route("/delete_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") def delete_llm(): req = request.json - TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], - TenantLLM.llm_name == req["llm_name"]]) + TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) -@manager.route('/enable_llm', methods=['POST']) # noqa: F821 +@manager.route("/enable_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") def enable_llm(): req = request.json TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], - TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}) + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} + ) return get_json_result(data=True) -@manager.route('/delete_factory', methods=['POST']) # noqa: 
F821 +@manager.route("/delete_factory", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") def delete_factory(): req = request.json - TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) + TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) -@manager.route('/my_llms', methods=['GET']) # noqa: F821 +@manager.route("/my_llms", methods=["GET"]) # noqa: F821 @login_required def my_llms(): try: - include_details = request.args.get('include_details', 'false').lower() == 'true' + include_details = request.args.get("include_details", "false").lower() == "true" if include_details: res = {} @@ -345,40 +313,31 @@ def my_llms(): break if o_dict["llm_factory"] not in res: - res[o_dict["llm_factory"]] = { - "tags": factory_tags, - "llm": [] - } + res[o_dict["llm_factory"]] = {"tags": factory_tags, "llm": []} - res[o_dict["llm_factory"]]["llm"].append({ - "type": o_dict["model_type"], - "name": o_dict["llm_name"], - "used_token": o_dict["used_tokens"], - "api_base": o_dict["api_base"] or "", - "max_tokens": o_dict["max_tokens"] or 8192, - "status": o_dict["status"] or "1" - }) + res[o_dict["llm_factory"]]["llm"].append( + { + "type": o_dict["model_type"], + "name": o_dict["llm_name"], + "used_token": o_dict["used_tokens"], + "api_base": o_dict["api_base"] or "", + "max_tokens": o_dict["max_tokens"] or 8192, + "status": o_dict["status"] or "1", + } + ) else: res = {} for o in TenantLLMService.get_my_llms(current_user.id): if o["llm_factory"] not in res: - res[o["llm_factory"]] = { - "tags": o["tags"], - "llm": [] - } - res[o["llm_factory"]]["llm"].append({ - "type": o["model_type"], - "name": o["llm_name"], - "used_token": o["used_tokens"], - "status": o["status"] - }) + res[o["llm_factory"]] = {"tags": o["tags"], "llm": []} + res[o["llm_factory"]]["llm"].append({"type": o["model_type"], "name": o["llm_name"], "used_token": o["used_tokens"], "status": o["status"]}) return get_json_result(data=res) except Exception as e: return server_error_response(e) -@manager.route('/list', methods=['GET']) # noqa: F821 +@manager.route("/list", methods=["GET"]) # noqa: F821 @login_required def list_app(): self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"] @@ -386,14 +345,13 @@ def list_app(): model_type = request.args.get("model_type") try: objs = TenantLLMService.query(tenant_id=current_user.id) - facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value]) + facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value]) status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value} llms = LLMService.get_all() - llms = [m.to_dict() - for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status] + llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status] for m in llms: m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed - if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''): + if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and 
m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""): m["available"] = True llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms]) diff --git a/api/settings.py b/api/settings.py index d44539b7c..02136fbef 100644 --- a/api/settings.py +++ b/api/settings.py @@ -46,6 +46,7 @@ HOST_IP = None HOST_PORT = None SECRET_KEY = None FACTORY_LLM_INFOS = None +ALLOWED_LLM_FACTORIES = None DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -77,7 +78,7 @@ STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8")) SMTP_CONF = None MAIL_SERVER = "" MAIL_PORT = 000 -MAIL_USE_SSL= True +MAIL_USE_SSL = True MAIL_USE_TLS = False MAIL_USERNAME = "" MAIL_PASSWORD = "" @@ -104,13 +105,14 @@ def get_or_create_secret_key(): def init_settings(): - global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED + global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") DATABASE = decrypt_database_config(name=DATABASE_TYPE) LLM = get_base_config("user_default_llm", {}) or {} LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {} LLM_FACTORY = LLM.get("factory", "") or "" LLM_BASE_URL = LLM.get("base_url", "") or "" + ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None) try: REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) except Exception: diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 88cf7f5b9..630073788 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -39,6 +39,7 @@ from common.constants import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from api.db.services.tenant_llm_service import LLMFactoriesService from common.connection_utils import timeout from common.constants import RetCode @@ -51,16 +52,15 @@ def serialize_for_json(obj): Recursively serialize objects to make them JSON serializable. Handles ModelMetaclass and other non-serializable objects. """ - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): # For objects with __dict__, try to serialize their attributes try: - return {key: serialize_for_json(value) for key, value in obj.__dict__.items() - if not key.startswith('_')} + return {key: serialize_for_json(value) for key, value in obj.__dict__.items() if not key.startswith("_")} except (AttributeError, TypeError): return str(obj) - elif hasattr(obj, '__name__'): + elif hasattr(obj, "__name__"): # For classes and metaclasses, return their name - return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, '__module__') else f"<{obj.__name__}>" + return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, "__module__") else f"<{obj.__name__}>" elif isinstance(obj, (list, tuple)): return [serialize_for_json(item) for item in obj] elif isinstance(obj, dict): @@ -71,6 +71,7 @@ def serialize_for_json(obj): # Fallback: convert to string representation return str(obj) + def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! 
Data missing!"): logging.exception(Exception(message)) result_dict = {"code": code, "message": message} @@ -99,8 +100,7 @@ def server_error_response(e): except Exception: return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None) if repr(e).find("index_not_found_exception") >= 0: - return get_json_result(code=RetCode.EXCEPTION_ERROR, - message="No chunk found, please upload file and parse it.") + return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) @@ -129,8 +129,7 @@ def validate_request(*args, **kwargs): if no_arguments: error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) if error_arguments: - error_string += "required argument values: {}".format( - ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string) return func(*_args, **_kwargs) @@ -145,8 +144,7 @@ def not_allowed_parameters(*params): input_arguments = flask_request.json or flask_request.form.to_dict() for param in params: if param in input_arguments: - return get_json_result(code=RetCode.ARGUMENT_ERROR, - message=f"Parameter {param} isn't allowed") + return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") return f(*args, **kwargs) return wrapper @@ -158,6 +156,7 @@ def active_required(f): @wraps(f) def wrapper(*args, **kwargs): from api.db.services import UserService + user_id = current_user.id usr = UserService.filter_by_id(user_id) # check is_active @@ -199,6 +198,7 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da else: return jsonify({"code": code, "message": message, "data": data}) + def token_required(func): @wraps(func) def decorated_function(*args, **kwargs): @@ -213,8 +213,7 @@ def token_required(func): token = authorization_list[1] objs = APIToken.query(token=token) if not objs: - return get_json_result(data=False, message="Authentication error: API key is invalid!", - code=RetCode.AUTHENTICATION_ERROR) + return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) kwargs["tenant_id"] = objs[0].tenant_id return func(*args, **kwargs) @@ -243,9 +242,10 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None): return jsonify(response) + def get_error_data_result( - message="Sorry! Data missing!", - code=RetCode.DATA_ERROR, + message="Sorry! 
Data missing!", + code=RetCode.DATA_ERROR, ): result_dict = {"code": code, "message": message} response = {} @@ -271,6 +271,7 @@ def get_error_operating_result(message="Operating error"): def generate_confirmation_token(): import secrets + return "ragflow-" + secrets.token_urlsafe(32) @@ -345,18 +346,7 @@ def get_parser_config(chunk_method, parser_config): return merged_config -def get_data_openai( - id=None, - created=None, - model=None, - prompt_tokens=0, - completion_tokens=0, - content=None, - finish_reason=None, - object="chat.completion", - param=None, - stream=False -): +def get_data_openai(id=None, created=None, model=None, prompt_tokens=0, completion_tokens=0, content=None, finish_reason=None, object="chat.completion", param=None, stream=False): total_tokens = prompt_tokens + completion_tokens if stream: @@ -364,11 +354,13 @@ def get_data_openai( "id": f"{id}", "object": "chat.completion.chunk", "model": model, - "choices": [{ - "delta": {"content": content}, - "finish_reason": finish_reason, - "index": 0, - }], + "choices": [ + { + "delta": {"content": content}, + "finish_reason": finish_reason, + "index": 0, + } + ], } return { @@ -387,15 +379,14 @@ def get_data_openai( "rejected_prediction_tokens": 0, }, }, - "choices": [{ - "message": { - "role": "assistant", - "content": content - }, - "logprobs": None, - "finish_reason": finish_reason, - "index": 0, - }], + "choices": [ + { + "message": {"role": "assistant", "content": content}, + "logprobs": None, + "finish_reason": finish_reason, + "index": 0, + } + ], } @@ -431,6 +422,7 @@ def check_duplicate_ids(ids, id_type="item"): def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService + """ Verifies availability of an embedding model for a specific tenant. 
@@ -469,11 +461,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) - is_tenant_model = any( - llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for - llm in tenant_llms) + is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) - is_builtin_model = llm_factory=='Builtin' + is_builtin_model = llm_factory == "Builtin" if not (is_builtin_model or is_tenant_model or in_llm_service): return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") @@ -610,7 +600,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s return {}, str(e) - async def is_strong_enough(chat_model, embedding_model): count = settings.STRONG_TEST_COUNT if not chat_model or not embedding_model: @@ -626,9 +615,7 @@ async def is_strong_enough(chat_model, embedding_model): _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) if chat_model: with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", - "content": "Are you strong enough!?"}], - {})) + res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) if res.find("**ERROR**") >= 0: raise Exception(res) @@ -636,3 +623,11 @@ async def is_strong_enough(chat_model, embedding_model): async with trio.open_nursery() as nursery: for _ in range(count): nursery.start_soon(_is_strong_enough) + + +def get_allowed_llm_factories() -> list: + factories = LLMFactoriesService.get_all() + if settings.ALLOWED_LLM_FACTORIES is None: + return factories + + return [factory for factory in factories if factory.name in settings.ALLOWED_LLM_FACTORIES] diff --git a/docs/configurations.md b/docs/configurations.md index 76fe1a81f..7aba314b4 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -29,9 +29,9 @@ docker compose -f docker/docker-compose.yml up -d ## Docker Compose -- **docker-compose.yml** +- **docker-compose.yml** Sets up environment for RAGFlow and its dependencies. -- **docker-compose-base.yml** +- **docker-compose-base.yml** Sets up environment for RAGFlow's dependencies: Elasticsearch/[Infinity](https://github.com/infiniflow/infinity), MySQL, MinIO, and Redis. :::danger IMPORTANT @@ -44,97 +44,97 @@ The [.env](https://github.com/infiniflow/ragflow/blob/main/docker/.env) file con ### Elasticsearch -- `STACK_VERSION` +- `STACK_VERSION` The version of Elasticsearch. Defaults to `8.11.3` -- `ES_PORT` +- `ES_PORT` The port used to expose the Elasticsearch service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `1200`. -- `ELASTIC_PASSWORD` +- `ELASTIC_PASSWORD` The password for Elasticsearch. ### Kibana -- `KIBANA_PORT` +- `KIBANA_PORT` The port used to expose the Kibana service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `6601`. -- `KIBANA_USER` +- `KIBANA_USER` The username for Kibana. Defaults to `rag_flow`. -- `KIBANA_PASSWORD` +- `KIBANA_PASSWORD` The password for Kibana. Defaults to `infini_rag_flow`. 
### Resource management -- `MEM_LIMIT` +- `MEM_LIMIT` The maximum amount of the memory, in bytes, that *a specific* Docker container can use while running. Defaults to `8073741824`. ### MySQL -- `MYSQL_PASSWORD` +- `MYSQL_PASSWORD` The password for MySQL. -- `MYSQL_PORT` +- `MYSQL_PORT` The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`. ### MinIO RAGFlow utilizes MinIO as its object storage solution, leveraging its scalability to store and manage all uploaded files. -- `MINIO_CONSOLE_PORT` +- `MINIO_CONSOLE_PORT` The port used to expose the MinIO console interface to the host machine, allowing **external** access to the web-based console running inside the Docker container. Defaults to `9001` -- `MINIO_PORT` +- `MINIO_PORT` The port used to expose the MinIO API service to the host machine, allowing **external** access to the MinIO object storage service running inside the Docker container. Defaults to `9000`. -- `MINIO_USER` +- `MINIO_USER` The username for MinIO. -- `MINIO_PASSWORD` +- `MINIO_PASSWORD` The password for MinIO. ### Redis -- `REDIS_PORT` +- `REDIS_PORT` The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`. -- `REDIS_PASSWORD` +- `REDIS_PASSWORD` The password for Redis. ### RAGFlow -- `SVR_HTTP_PORT` +- `SVR_HTTP_PORT` The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`. -- `RAGFLOW-IMAGE` - The Docker image edition. Available editions: - - - `infiniflow/ragflow:v0.21.1-slim` (default): The RAGFlow Docker image without embedding models. +- `RAGFLOW-IMAGE` + The Docker image edition. Available editions: + + - `infiniflow/ragflow:v0.21.1-slim` (default): The RAGFlow Docker image without embedding models. - `infiniflow/ragflow:v0.21.1`: The RAGFlow Docker image with embedding models including: - Built-in embedding models: - - `BAAI/bge-large-zh-v1.5` + - `BAAI/bge-large-zh-v1.5` - `maidalun1020/bce-embedding-base_v1` -:::tip NOTE -If you cannot download the RAGFlow Docker image, try the following mirrors. +:::tip NOTE +If you cannot download the RAGFlow Docker image, try the following mirrors. -- For the `nightly` edition: +- For the `nightly` edition: - `RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly` or, - `RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:nightly`. ::: ### Embedding service -- `TEI_MODEL` +- `TEI_MODEL` The embedding model which text-embeddings-inference serves. Allowed values are one of `Qwen/Qwen3-Embedding-0.6B`(default), `BAAI/bge-m3`, and `BAAI/bge-small-en-v1.5`. -- `TEI_PORT` +- `TEI_PORT` The port used to expose the text-embeddings-inference service to the host machine, allowing **external** access to the text-embeddings-inference service running inside the Docker container. Defaults to `6380`. ### Timezone -- `TZ` +- `TZ` The local time zone. Defaults to `Asia/Shanghai`. ### Hugging Face mirror site -- `HF_ENDPOINT` +- `HF_ENDPOINT` The mirror site for huggingface.co. It is disabled by default. You can uncomment this line if you have limited access to the primary Hugging Face domain. ### MacOS -- `MACOS` +- `MACOS` Optimizations for macOS. It is disabled by default. You can uncomment this line if your OS is macOS. 
### User registration @@ -153,7 +153,7 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - `port`: The API server's serving port inside the Docker container. Defaults to `9380`. ### `mysql` - + - `name`: The MySQL database name. Defaults to `rag_flow`. - `user`: The username for MySQL. - `password`: The password for MySQL. @@ -162,12 +162,12 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - `stale_timeout`: Timeout in seconds. ### `minio` - + - `user`: The username for MinIO. - `password`: The password for MinIO. - `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`. -### `oauth` +### `oauth` The OAuth configuration for signing up or signing in to RAGFlow using a third-party account. @@ -184,7 +184,7 @@ The OAuth configuration for signing up or signing in to RAGFlow using a third-pa - `scope`: Requested permission scope, a space-separated string. For example, `openid profile email`. - `redirect_uri`: Required, URI to which the authorization server redirects during the authentication flow to return results. Must match the callback URI registered with the authentication server. Format: `https://your-app.com/v1/user/oauth/callback/`. For local configuration, you can directly use `http://127.0.0.1:80/v1/user/oauth/callback/`. -:::tip NOTE +:::tip NOTE The following are best practices for configuring various third-party authentication methods. You can configure one or multiple third-party authentication methods for Ragflow: ```yaml oauth: @@ -216,9 +216,9 @@ oauth: ``` ::: -### `user_default_llm` +### `user_default_llm` -The default LLM to use for a new RAGFlow user. It is disabled by default. To enable this feature, uncomment the corresponding lines in **service_conf.yaml.template**. +The default LLM to use for a new RAGFlow user. It is disabled by default. To enable this feature, uncomment the corresponding lines in **service_conf.yaml.template**. - `factory`: The LLM supplier. Available options: - `"OpenAI"` @@ -228,7 +228,11 @@ The default LLM to use for a new RAGFlow user. It is disabled by default. To ena - `"VolcEngine"` - `"ZHIPU-AI"` - `api_key`: The API key for the specified LLM. You will need to apply for your model API key online. +- `allowed_factories`: If set, users can add models only from the factories named in this list (see the sample configuration at the end of this page), for example: + - `"OpenAI"` + - `"DeepSeek"` + - `"Moonshot"` -:::tip NOTE +:::tip NOTE If you do not set the default LLM here, configure the default LLM on the **Settings** page in the RAGFlow UI. -::: \ No newline at end of file +:::
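A minimal sketch of a `user_default_llm` section in **service_conf.yaml.template** that uses `allowed_factories`. It assumes the option is given as a YAML list of factory names, as the bullets above suggest; the API key value is a placeholder, and any other keys in your template stay unchanged:

```yaml
user_default_llm:
  factory: "OpenAI"
  api_key: "sk-xxxxxxxx"    # placeholder, replace with your own key
  allowed_factories:        # users may add models only from these providers
    - "OpenAI"
    - "DeepSeek"
    - "Moonshot"
```

With a configuration like this, the `/factories` endpoint lists only the providers named above, and `add_llm` rejects attempts to add a model from any other factory.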