feat: add allowed factories variable to allow admins to restrict llms users can add (#11003)
### What problem does this PR solve?

Currently, restricting which LLM factories users may use requires deleting rows from the database table manually. This PR introduces a variable that, if set, restricts the LLM factories users can see and add. That way, neither llm_factories.json nor the database has to be touched when the LLM factory is already inserted.

Note: all the lint changes come from the pre-commit hook, which I did not modify.

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
Committed by: GitHub
Parent: bab3fce136
Commit: 3654ae61c1
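The diff below wires a new helper, `get_allowed_llm_factories()` (imported from `api.utils.api_utils`), into both the `/factories` listing and the `add_llm` endpoint. The helper's implementation is not part of this diff, so the following is only a minimal sketch of the idea under stated assumptions: the allow-list variable name `ALLOWED_LLM_FACTORIES` and its comma-separated format are illustrative, not taken from this PR, and `LLMFactoriesService` is assumed to be importable from `api.db.services.llm_service`.

```python
# Hypothetical sketch of the helper used in this diff; the real implementation
# lives in api.utils.api_utils and may read its allow-list from another setting.
import os

from api.db.services.llm_service import LLMFactoriesService  # assumed import path


def get_allowed_llm_factories():
    """Return LLM factories, optionally restricted by an admin-defined allow-list.

    ALLOWED_LLM_FACTORIES is an illustrative, comma-separated environment
    variable (e.g. "OpenAI,Ollama"). When it is unset or empty, every factory
    stays visible, so existing deployments keep their current behavior.
    """
    factories = LLMFactoriesService.get_all()
    raw = os.getenv("ALLOWED_LLM_FACTORIES", "").strip()
    if not raw:
        return factories
    allowed = {name.strip() for name in raw.split(",") if name.strip()}
    return [f for f in factories if f.name in allowed]
```

Routing both the factory listing and the `add_llm` check through a single helper is what lets admins avoid editing `llm_factories.json` or deleting database rows: factories that are already inserted simply stop being offered or accepted once they fall outside the allow-list.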
@@ -23,16 +23,16 @@ from api.db.services.llm_service import LLMService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result
+from api.utils.api_utils import get_json_result, get_allowed_llm_factories
 from common.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel


-@manager.route('/factories', methods=['GET']) # noqa: F821
+@manager.route("/factories", methods=["GET"]) # noqa: F821
 @login_required
 def factories():
     try:
-        fac = LLMFactoriesService.get_all()
+        fac = get_allowed_llm_factories()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
         llms = LLMService.get_all()
         mdl_types = {}
@@ -43,14 +43,13 @@ def factories():
                 mdl_types[m.fid] = set([])
             mdl_types[m.fid].add(m.model_type)
         for f in fac:
-            f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK,
-                                                              LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
+            f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
         return get_json_result(data=fac)
     except Exception as e:
         return server_error_response(e)


-@manager.route('/set_api_key', methods=['POST']) # noqa: F821
+@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("llm_factory", "api_key")
 def set_api_key():
@@ -63,8 +62,7 @@ def set_api_key():
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-            mdl = EmbeddingModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+            mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@@ -74,52 +72,40 @@ def set_api_key():
                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-            mdl = ChatModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
+            mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
             try:
-                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
-                                 {"temperature": 0.9, 'max_tokens': 50})
+                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
                 if m.find("**ERROR**") >= 0:
                     raise Exception(m)
                 chat_passed = True
             except Exception as e:
-                msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(
-                    e)
+                msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
-            mdl = RerankModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+            mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
                     raise Exception("Fail")
                 rerank_passed = True
-                logging.debug(f'passed model rerank {llm.llm_name}')
+                logging.debug(f"passed model rerank {llm.llm_name}")
             except Exception as e:
-                msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(
-                    e)
+                msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
         if any([embd_passed, chat_passed, rerank_passed]):
-            msg = ''
+            msg = ""
             break

     if msg:
         return get_data_error_result(message=msg)

-    llm_config = {
-        "api_key": req["api_key"],
-        "api_base": req.get("base_url", "")
-    }
+    llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")}
     for n in ["model_type", "llm_name"]:
         if n in req:
             llm_config[n] = req[n]

     for llm in LLMService.query(fid=factory):
-        llm_config["max_tokens"]=llm.max_tokens
-        if not TenantLLMService.filter_update(
-                [TenantLLM.tenant_id == current_user.id,
-                 TenantLLM.llm_factory == factory,
-                 TenantLLM.llm_name == llm.llm_name],
-                llm_config):
+        llm_config["max_tokens"] = llm.max_tokens
+        if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config):
             TenantLLMService.save(
                 tenant_id=current_user.id,
                 llm_factory=factory,
@@ -127,13 +113,13 @@ def set_api_key():
                 model_type=llm.model_type,
                 api_key=llm_config["api_key"],
                 api_base=llm_config["api_base"],
-                max_tokens=llm_config["max_tokens"]
+                max_tokens=llm_config["max_tokens"],
             )

     return get_json_result(data=True)


-@manager.route('/add_llm', methods=['POST']) # noqa: F821
+@manager.route("/add_llm", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("llm_factory")
 def add_llm():
@@ -142,6 +128,9 @@ def add_llm():
     api_key = req.get("api_key", "x")
     llm_name = req.get("llm_name")

+    if factory not in get_allowed_llm_factories():
+        return get_data_error_result(message=f"LLM factory {factory} is not allowed")
+
     def apikey_json(keys):
         nonlocal req
         return json.dumps({k: req.get(k, "") for k in keys})
@@ -204,7 +193,7 @@ def add_llm():
         "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
         "api_key": api_key,
-        "max_tokens": req.get("max_tokens")
+        "max_tokens": req.get("max_tokens"),
     }

     msg = ""
@@ -212,10 +201,7 @@ def add_llm():
     extra = {"provider": factory}
     if llm["model_type"] == LLMType.EMBEDDING.value:
         assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-        mdl = EmbeddingModel[factory](
-            key=llm['api_key'],
-            model_name=mdl_nm,
-            base_url=llm["api_base"])
+        mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
         try:
             arr, tc = mdl.encode(["Test if the api key is available"])
             if len(arr[0]) == 0:
@@ -225,42 +211,31 @@ def add_llm():
     elif llm["model_type"] == LLMType.CHAT.value:
         assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
         mdl = ChatModel[factory](
-            key=llm['api_key'],
+            key=llm["api_key"],
             model_name=mdl_nm,
             base_url=llm["api_base"],
             **extra,
         )
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
-                "temperature": 0.9})
+            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
             if not tc and m.find("**ERROR**:") >= 0:
                 raise Exception(m)
         except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(
-                e)
+            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
     elif llm["model_type"] == LLMType.RERANK:
         assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
         try:
-            mdl = RerankModel[factory](
-                key=llm["api_key"],
-                model_name=mdl_nm,
-                base_url=llm["api_base"]
-            )
+            mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
             arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
             if len(arr) == 0:
                 raise Exception("Not known.")
         except KeyError:
             msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
         except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(
-                e)
+            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
     elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
         assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
-        mdl = CvModel[factory](
-            key=llm["api_key"],
-            model_name=mdl_nm,
-            base_url=llm["api_base"]
-        )
+        mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
         try:
             image_data = test_image
             m, tc = mdl.describe(image_data)
@@ -270,9 +245,7 @@ def add_llm():
             msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
     elif llm["model_type"] == LLMType.TTS:
         assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
-        mdl = TTSModel[factory](
-            key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]
-        )
+        mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
         try:
             for resp in mdl.tts("Hello~ RAGFlower!"):
                 pass
@@ -285,51 +258,46 @@ def add_llm():
     if msg:
         return get_data_error_result(message=msg)

-    if not TenantLLMService.filter_update(
-            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory,
-             TenantLLM.llm_name == llm["llm_name"]], llm):
+    if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
         TenantLLMService.save(**llm)

     return get_json_result(data=True)


-@manager.route('/delete_llm', methods=['POST']) # noqa: F821
+@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("llm_factory", "llm_name")
 def delete_llm():
     req = request.json
-    TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
-         TenantLLM.llm_name == req["llm_name"]])
+    TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
     return get_json_result(data=True)


-@manager.route('/enable_llm', methods=['POST']) # noqa: F821
+@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("llm_factory", "llm_name")
 def enable_llm():
     req = request.json
     TenantLLMService.filter_update(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
-         TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))})
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
+    )
     return get_json_result(data=True)


-@manager.route('/delete_factory', methods=['POST']) # noqa: F821
+@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("llm_factory")
 def delete_factory():
     req = request.json
-    TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+    TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
     return get_json_result(data=True)


-@manager.route('/my_llms', methods=['GET']) # noqa: F821
+@manager.route("/my_llms", methods=["GET"]) # noqa: F821
 @login_required
 def my_llms():
     try:
-        include_details = request.args.get('include_details', 'false').lower() == 'true'
+        include_details = request.args.get("include_details", "false").lower() == "true"

         if include_details:
             res = {}
@@ -345,40 +313,31 @@ def my_llms():
                         break

                 if o_dict["llm_factory"] not in res:
-                    res[o_dict["llm_factory"]] = {
-                        "tags": factory_tags,
-                        "llm": []
-                    }
+                    res[o_dict["llm_factory"]] = {"tags": factory_tags, "llm": []}

-                res[o_dict["llm_factory"]]["llm"].append({
-                    "type": o_dict["model_type"],
-                    "name": o_dict["llm_name"],
-                    "used_token": o_dict["used_tokens"],
-                    "api_base": o_dict["api_base"] or "",
-                    "max_tokens": o_dict["max_tokens"] or 8192,
-                    "status": o_dict["status"] or "1"
-                })
+                res[o_dict["llm_factory"]]["llm"].append(
+                    {
+                        "type": o_dict["model_type"],
+                        "name": o_dict["llm_name"],
+                        "used_token": o_dict["used_tokens"],
+                        "api_base": o_dict["api_base"] or "",
+                        "max_tokens": o_dict["max_tokens"] or 8192,
+                        "status": o_dict["status"] or "1",
+                    }
+                )
         else:
             res = {}
             for o in TenantLLMService.get_my_llms(current_user.id):
                 if o["llm_factory"] not in res:
-                    res[o["llm_factory"]] = {
-                        "tags": o["tags"],
-                        "llm": []
-                    }
-                res[o["llm_factory"]]["llm"].append({
-                    "type": o["model_type"],
-                    "name": o["llm_name"],
-                    "used_token": o["used_tokens"],
-                    "status": o["status"]
-                })
+                    res[o["llm_factory"]] = {"tags": o["tags"], "llm": []}
+                res[o["llm_factory"]]["llm"].append({"type": o["model_type"], "name": o["llm_name"], "used_token": o["used_tokens"], "status": o["status"]})

         return get_json_result(data=res)
     except Exception as e:
         return server_error_response(e)


-@manager.route('/list', methods=['GET']) # noqa: F821
+@manager.route("/list", methods=["GET"]) # noqa: F821
 @login_required
 def list_app():
     self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
@@ -386,14 +345,13 @@ def list_app():
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
-        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value])
+        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
         status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
         llms = LLMService.get_all()
-        llms = [m.to_dict()
-                for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
+        llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
-            if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''):
+            if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):
                 m["available"] = True

         llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])