mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat/admin service (#10233)
### What problem does this PR solve? - Admin client support show user and create user command. - Admin client support alter user password and active status. - Admin client support list user datasets. issue: #10241 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -1,5 +1,7 @@
|
||||
import argparse
|
||||
import base64
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from typing import Dict, List, Any
|
||||
from lark import Lark, Transformer, Tree
|
||||
import requests
|
||||
@ -19,6 +21,8 @@ sql_command: list_services
|
||||
| show_user
|
||||
| drop_user
|
||||
| alter_user
|
||||
| create_user
|
||||
| activate_user
|
||||
| list_datasets
|
||||
| list_agents
|
||||
|
||||
@ -35,6 +39,7 @@ meta_arg: /[^\\s"']+/ | quoted_string
|
||||
LIST: "LIST"i
|
||||
SERVICES: "SERVICES"i
|
||||
SHOW: "SHOW"i
|
||||
CREATE: "CREATE"i
|
||||
SERVICE: "SERVICE"i
|
||||
SHUTDOWN: "SHUTDOWN"i
|
||||
STARTUP: "STARTUP"i
|
||||
@ -43,6 +48,7 @@ USERS: "USERS"i
|
||||
DROP: "DROP"i
|
||||
USER: "USER"i
|
||||
ALTER: "ALTER"i
|
||||
ACTIVE: "ACTIVE"i
|
||||
PASSWORD: "PASSWORD"i
|
||||
DATASETS: "DATASETS"i
|
||||
OF: "OF"i
|
||||
@ -58,12 +64,15 @@ list_users: LIST USERS ";"
|
||||
drop_user: DROP USER quoted_string ";"
|
||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||
show_user: SHOW USER quoted_string ";"
|
||||
create_user: CREATE USER quoted_string quoted_string ";"
|
||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||
|
||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||
list_agents: LIST AGENTS OF quoted_string ";"
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
status: WORD
|
||||
|
||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||
@ -118,6 +127,16 @@ class AdminTransformer(Transformer):
|
||||
new_password = items[4]
|
||||
return {"type": "alter_user", "username": user_name, "password": new_password}
|
||||
|
||||
def create_user(self, items):
|
||||
user_name = items[2]
|
||||
password = items[3]
|
||||
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
|
||||
|
||||
def activate_user(self, items):
|
||||
user_name = items[3]
|
||||
activate_status = items[4]
|
||||
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
|
||||
|
||||
def list_datasets(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_datasets", "username": user_name}
|
||||
@ -152,6 +171,14 @@ def encode_to_base64(input_string):
|
||||
return base64_encoded.decode('utf-8')
|
||||
|
||||
|
||||
def encrypt(input_string):
|
||||
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
||||
pub_key = RSA.importKey(pub)
|
||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
||||
return base64.b64encode(cipher_text).decode("utf-8")
|
||||
|
||||
|
||||
class AdminCommandParser:
|
||||
def __init__(self):
|
||||
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
||||
@ -220,6 +247,9 @@ class AdminCLI:
|
||||
if not data:
|
||||
print("No data to print")
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(data[0].keys())
|
||||
col_widths = {}
|
||||
@ -335,6 +365,10 @@ class AdminCLI:
|
||||
self._handle_drop_user(command_dict)
|
||||
case 'alter_user':
|
||||
self._handle_alter_user(command_dict)
|
||||
case 'create_user':
|
||||
self._handle_create_user(command_dict)
|
||||
case 'activate_user':
|
||||
self._handle_activate_user(command_dict)
|
||||
case 'list_datasets':
|
||||
self._handle_list_datasets(command_dict)
|
||||
case 'list_agents':
|
||||
@ -349,9 +383,8 @@ class AdminCLI:
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = dict
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
@ -377,9 +410,8 @@ class AdminCLI:
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = dict
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
@ -388,6 +420,13 @@ class AdminCLI:
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Showing user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_drop_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
@ -400,16 +439,73 @@ class AdminCLI:
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
print(f"Alter user: {username}, password: {password}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'new_password': encrypt(password)})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_create_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
role: str = command['role']
|
||||
print(f"Create user: {username}, password: {password}, role: {role}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.post(
|
||||
url,
|
||||
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'username': username, 'password': encrypt(password), 'role': role}
|
||||
)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_activate_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
activate_tree: Tree = command['activate_status']
|
||||
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||
if activate_status.lower() in ['on', 'off']:
|
||||
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'activate_status': activate_status})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||
else:
|
||||
print(f"Unknown activate status: {activate_status}.")
|
||||
|
||||
def _handle_list_datasets(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all datasets of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_list_agents(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all agents of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_meta_command(self, command):
|
||||
meta_command = command['command']
|
||||
|
||||
@ -10,6 +10,7 @@ from flask import Flask
|
||||
from routes import admin_bp
|
||||
from api.utils.log_utils import init_root_logger
|
||||
from api.constants import SERVICE_CONF
|
||||
from api import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
|
||||
stop_event = threading.Event()
|
||||
@ -26,7 +27,7 @@ if __name__ == '__main__':
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(admin_bp)
|
||||
|
||||
settings.init_settings()
|
||||
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||
|
||||
try:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from flask import Blueprint, request
|
||||
|
||||
from auth import login_verify
|
||||
from responses import success_response, error_response
|
||||
from services import UserMgr, ServiceMgr
|
||||
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||
from exceptions import AdminException
|
||||
|
||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
@ -38,13 +39,18 @@ def create_user():
|
||||
password = data['password']
|
||||
role = data.get('role', 'user')
|
||||
|
||||
user = UserMgr.create_user(username, password, role)
|
||||
return success_response(user, "User created successfully", 201)
|
||||
res = UserMgr.create_user(username, password, role)
|
||||
if res["success"]:
|
||||
user_info = res["user_info"]
|
||||
user_info.pop("password") # do not return password
|
||||
return success_response(user_info, "User created successfully")
|
||||
else:
|
||||
return error_response("create user failed")
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
return error_response(str(e))
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||
@ -69,8 +75,8 @@ def change_password(username):
|
||||
return error_response("New password is required", 400)
|
||||
|
||||
new_password = data['new_password']
|
||||
UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, "Password updated successfully")
|
||||
msg = UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, msg)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
@ -78,6 +84,21 @@ def change_password(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||
@login_verify
|
||||
def alter_user_activate_status(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'activate_status' not in data:
|
||||
return error_response("Activation status is required", 400)
|
||||
activate_status = data['activate_status']
|
||||
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||
return success_response(None, msg)
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_details(username):
|
||||
@ -90,6 +111,31 @@ def get_user_details(username):
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_datasets(username):
|
||||
try:
|
||||
datasets_list = UserServiceMgr.get_user_datasets(username)
|
||||
return success_response(datasets_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_agents(username):
|
||||
try:
|
||||
agents_list = UserServiceMgr.get_user_agents(username)
|
||||
return success_response(agents_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services', methods=['GET'])
|
||||
@login_verify
|
||||
|
||||
@ -1,5 +1,13 @@
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from api.db import ActiveEnum
|
||||
from api.db.services import UserService
|
||||
from exceptions import AdminException
|
||||
from api.db.joint_services.user_account_service import create_new_user
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.crypt import decrypt
|
||||
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
|
||||
from config import SERVICE_CONFIGS
|
||||
|
||||
class UserMgr:
|
||||
@ -13,19 +21,120 @@ class UserMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_details(username):
|
||||
raise AdminException("get_user_details: not implemented")
|
||||
# use email to query
|
||||
users = UserService.query_user_by_email(username)
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'email': user.email,
|
||||
'language': user.language,
|
||||
'last_login_time': user.last_login_time,
|
||||
'is_authenticated': user.is_authenticated,
|
||||
'is_active': user.is_active,
|
||||
'is_anonymous': user.is_anonymous,
|
||||
'login_channel': user.login_channel,
|
||||
'status': user.status,
|
||||
'is_superuser': user.is_superuser,
|
||||
'create_date': user.create_date,
|
||||
'update_date': user.update_date
|
||||
})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def create_user(username, password, role="user"):
|
||||
raise AdminException("create_user: not implemented")
|
||||
def create_user(username, password, role="user") -> dict:
|
||||
# Validate the email address
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
|
||||
raise AdminException(f"Invalid email address: {username}!")
|
||||
# Check if the email address is already used
|
||||
if UserService.query(email=username):
|
||||
raise UserAlreadyExistsError(username)
|
||||
# Construct user info data
|
||||
user_info_dict = {
|
||||
"email": username,
|
||||
"nickname": "", # ask user to edit it manually in settings.
|
||||
"password": decrypt(password),
|
||||
"login_channel": "password",
|
||||
"is_superuser": role == "admin",
|
||||
}
|
||||
return create_new_user(user_info_dict)
|
||||
|
||||
@staticmethod
|
||||
def delete_user(username):
|
||||
# use email to delete
|
||||
raise AdminException("delete_user: not implemented")
|
||||
|
||||
@staticmethod
|
||||
def update_user_password(username, new_password):
|
||||
raise AdminException("update_user_password: not implemented")
|
||||
def update_user_password(username, new_password) -> str:
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check new_password different from old.
|
||||
usr = user_list[0]
|
||||
psw = decrypt(new_password)
|
||||
if check_password_hash(usr.password, psw):
|
||||
return "Same password, no need to update!"
|
||||
# update password
|
||||
UserService.update_user_password(usr.id, psw)
|
||||
return "Password updated successfully!"
|
||||
|
||||
@staticmethod
|
||||
def update_user_activate_status(username, activate_status: str):
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check activate status different from new
|
||||
usr = user_list[0]
|
||||
# format activate_status before handle
|
||||
_activate_status = activate_status.lower()
|
||||
target_status = {
|
||||
'on': ActiveEnum.ACTIVE.value,
|
||||
'off': ActiveEnum.INACTIVE.value,
|
||||
}.get(_activate_status)
|
||||
if not target_status:
|
||||
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||
if target_status == usr.is_active:
|
||||
return f"User activate status is already {_activate_status}!"
|
||||
# update is_active
|
||||
UserService.update_user(usr.id, {"is_active": target_status})
|
||||
return f"Turn {_activate_status} user activate status successfully!"
|
||||
|
||||
class UserServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_datasets(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted kb and owned kb
|
||||
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agents(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted agents and owned agents
|
||||
return UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
|
||||
class ServiceMgr:
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters, active_required
|
||||
from api.utils import get_uuid
|
||||
from api.db import StatusEnum, FileSource
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -38,6 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@active_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
|
||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
||||
INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
|
||||
0
api/db/joint_services/__init__.py
Normal file
0
api/db/joint_services/__init__.py
Normal file
120
api/db/joint_services/user_account_service.py
Normal file
120
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,120 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from api import settings
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
|
||||
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
"""
|
||||
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||
:param user_info: {
|
||||
"email": <example@example.com>,
|
||||
"nickname": <str, "name">,
|
||||
"password": <decrypted password>,
|
||||
"login_channel": <enum, "password">,
|
||||
"is_superuser": <bool, role == "admin">,
|
||||
}
|
||||
:return: {
|
||||
"success": <bool>,
|
||||
"user_info": <dict>, # if true, return user_info
|
||||
}
|
||||
"""
|
||||
# generate user_id and access_token for user
|
||||
user_id = uuid.uuid1().hex
|
||||
user_info['id'] = user_id
|
||||
user_info['access_token'] = uuid.uuid1().hex
|
||||
# construct tenant info
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
# construct file folder info
|
||||
file_id = uuid.uuid1().hex
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
try:
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
return {"success": False}
|
||||
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"user_info": user_info,
|
||||
}
|
||||
|
||||
except Exception as create_error:
|
||||
logging.exception(create_error)
|
||||
# rollback
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
FileService.delete_by_id(file["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# delete user row finally
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# reraise
|
||||
raise create_error
|
||||
@ -61,6 +61,36 @@ class UserCanvasService(CommonService):
|
||||
|
||||
return list(agents.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
ag_batch = agents.offset(offset).limit(limit)
|
||||
_temp = list(ag_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, pid):
|
||||
|
||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.status,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date
|
||||
]
|
||||
# find team kb and owned kb
|
||||
kbs = cls.model.select(*fields).where(
|
||||
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time asc
|
||||
kbs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later.
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
kb_batch = kbs.offset(offset).limit(limit)
|
||||
_temp = list(kb_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
|
||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user_by_email(cls, email):
|
||||
users = cls.model.select().where((cls.model.email == email))
|
||||
return list(users)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user_password(cls, user_id, new_password):
|
||||
with DB.atomic():
|
||||
update_dict = {
|
||||
"password": generate_password_hash(str(new_password)),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now())
|
||||
}
|
||||
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
|
||||
@ -39,6 +39,7 @@ from flask import (
|
||||
make_response,
|
||||
send_file,
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
@ -48,7 +49,9 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services import UserService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
@ -226,6 +229,18 @@ def not_allowed_parameters(*params):
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_localhost(ip):
|
||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||
|
||||
|
||||
@ -23,6 +23,9 @@ from api.utils import file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
"""
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||
"""
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
|
||||
@ -949,7 +949,8 @@ export default {
|
||||
multimodalModels: 'Мультимодальные модели',
|
||||
textOnlyModels: 'Только текстовые модели',
|
||||
allModels: 'Все модели',
|
||||
codeExecDescription: 'Напишите свою пользовательскую логику на Python или Javascript.',
|
||||
codeExecDescription:
|
||||
'Напишите свою пользовательскую логику на Python или Javascript.',
|
||||
stringTransformDescription:
|
||||
'Изменяет текстовое содержимое. В настоящее время поддерживает: разделение или объединение текста.',
|
||||
foundation: 'Основа',
|
||||
|
||||
Reference in New Issue
Block a user