mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
feat: add allowed factories variable to allow admins to restrict llms users can add (#11003)
### What problem does this PR solve? Currently, if we want to restrict the allowed factories users can use we need to delete from the database table manually. The proposal of this PR is to include a variable to that, if set, will restrict the LLM factories the users can see and add. This allow us to not touch the llm_factories.json or the database if the LLM factory is already inserted. Obs.: All the lint changes were from the pre-commit hook which I did not change. ### Type of change - [X] New Feature (non-breaking change which adds functionality)
This commit is contained in:
committed by
GitHub
parent
bab3fce136
commit
3654ae61c1
@ -39,6 +39,7 @@ from common.constants import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
|
||||
@ -51,16 +52,15 @@ def serialize_for_json(obj):
|
||||
Recursively serialize objects to make them JSON serializable.
|
||||
Handles ModelMetaclass and other non-serializable objects.
|
||||
"""
|
||||
if hasattr(obj, '__dict__'):
|
||||
if hasattr(obj, "__dict__"):
|
||||
# For objects with __dict__, try to serialize their attributes
|
||||
try:
|
||||
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
|
||||
if not key.startswith('_')}
|
||||
return {key: serialize_for_json(value) for key, value in obj.__dict__.items() if not key.startswith("_")}
|
||||
except (AttributeError, TypeError):
|
||||
return str(obj)
|
||||
elif hasattr(obj, '__name__'):
|
||||
elif hasattr(obj, "__name__"):
|
||||
# For classes and metaclasses, return their name
|
||||
return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, '__module__') else f"<{obj.__name__}>"
|
||||
return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, "__module__") else f"<{obj.__name__}>"
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [serialize_for_json(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
@ -71,6 +71,7 @@ def serialize_for_json(obj):
|
||||
# Fallback: convert to string representation
|
||||
return str(obj)
|
||||
|
||||
|
||||
def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"):
|
||||
logging.exception(Exception(message))
|
||||
result_dict = {"code": code, "message": message}
|
||||
@ -99,8 +100,7 @@ def server_error_response(e):
|
||||
except Exception:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR,
|
||||
message="No chunk found, please upload file and parse it.")
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||
|
||||
@ -129,8 +129,7 @@ def validate_request(*args, **kwargs):
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(
|
||||
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
|
||||
return func(*_args, **_kwargs)
|
||||
|
||||
@ -145,8 +144,7 @@ def not_allowed_parameters(*params):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
for param in params:
|
||||
if param in input_arguments:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR,
|
||||
message=f"Parameter {param} isn't allowed")
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@ -158,6 +156,7 @@ def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
from api.db.services import UserService
|
||||
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
@ -199,6 +198,7 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
else:
|
||||
return jsonify({"code": code, "message": message, "data": data})
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
@ -213,8 +213,7 @@ def token_required(func):
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(data=False, message="Authentication error: API key is invalid!",
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@ -243,9 +242,10 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def get_error_data_result(
|
||||
message="Sorry! Data missing!",
|
||||
code=RetCode.DATA_ERROR,
|
||||
message="Sorry! Data missing!",
|
||||
code=RetCode.DATA_ERROR,
|
||||
):
|
||||
result_dict = {"code": code, "message": message}
|
||||
response = {}
|
||||
@ -271,6 +271,7 @@ def get_error_operating_result(message="Operating error"):
|
||||
|
||||
def generate_confirmation_token():
|
||||
import secrets
|
||||
|
||||
return "ragflow-" + secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@ -345,18 +346,7 @@ def get_parser_config(chunk_method, parser_config):
|
||||
return merged_config
|
||||
|
||||
|
||||
def get_data_openai(
|
||||
id=None,
|
||||
created=None,
|
||||
model=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
content=None,
|
||||
finish_reason=None,
|
||||
object="chat.completion",
|
||||
param=None,
|
||||
stream=False
|
||||
):
|
||||
def get_data_openai(id=None, created=None, model=None, prompt_tokens=0, completion_tokens=0, content=None, finish_reason=None, object="chat.completion", param=None, stream=False):
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
if stream:
|
||||
@ -364,11 +354,13 @@ def get_data_openai(
|
||||
"id": f"{id}",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"delta": {"content": content},
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0,
|
||||
}],
|
||||
"choices": [
|
||||
{
|
||||
"delta": {"content": content},
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return {
|
||||
@ -387,15 +379,14 @@ def get_data_openai(
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
},
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0,
|
||||
}],
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@ -431,6 +422,7 @@ def check_duplicate_ids(ids, id_type="item"):
|
||||
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
|
||||
"""
|
||||
Verifies availability of an embedding model for a specific tenant.
|
||||
|
||||
@ -469,11 +461,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
|
||||
in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))
|
||||
|
||||
tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
|
||||
is_tenant_model = any(
|
||||
llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for
|
||||
llm in tenant_llms)
|
||||
is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
|
||||
|
||||
is_builtin_model = llm_factory=='Builtin'
|
||||
is_builtin_model = llm_factory == "Builtin"
|
||||
if not (is_builtin_model or is_tenant_model or in_llm_service):
|
||||
return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
|
||||
|
||||
@ -610,7 +600,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s
|
||||
return {}, str(e)
|
||||
|
||||
|
||||
|
||||
async def is_strong_enough(chat_model, embedding_model):
|
||||
count = settings.STRONG_TEST_COUNT
|
||||
if not chat_model or not embedding_model:
|
||||
@ -626,9 +615,7 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
|
||||
if chat_model:
|
||||
with trio.fail_after(30):
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user",
|
||||
"content": "Are you strong enough!?"}],
|
||||
{}))
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
|
||||
if res.find("**ERROR**") >= 0:
|
||||
raise Exception(res)
|
||||
|
||||
@ -636,3 +623,11 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(_is_strong_enough)
|
||||
|
||||
|
||||
def get_allowed_llm_factories() -> list:
|
||||
factories = LLMFactoriesService.get_all()
|
||||
if settings.ALLOWED_LLM_FACTORIES is None:
|
||||
return factories
|
||||
|
||||
return [factory for factory in factories if factory.name in settings.ALLOWED_LLM_FACTORIES]
|
||||
|
||||
Reference in New Issue
Block a user