mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-03 17:15:08 +08:00
Compare commits
5 Commits
4aa1abd8e5
...
b0b866c8fd
| Author | SHA1 | Date | |
|---|---|---|---|
| b0b866c8fd | |||
| 3a831d0c28 | |||
| 9e323a9351 | |||
| 7ac95b759b | |||
| daea357940 |
@ -1,5 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import base64
|
import base64
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
from lark import Lark, Transformer, Tree
|
from lark import Lark, Transformer, Tree
|
||||||
import requests
|
import requests
|
||||||
@ -19,6 +21,8 @@ sql_command: list_services
|
|||||||
| show_user
|
| show_user
|
||||||
| drop_user
|
| drop_user
|
||||||
| alter_user
|
| alter_user
|
||||||
|
| create_user
|
||||||
|
| activate_user
|
||||||
| list_datasets
|
| list_datasets
|
||||||
| list_agents
|
| list_agents
|
||||||
|
|
||||||
@ -35,6 +39,7 @@ meta_arg: /[^\\s"']+/ | quoted_string
|
|||||||
LIST: "LIST"i
|
LIST: "LIST"i
|
||||||
SERVICES: "SERVICES"i
|
SERVICES: "SERVICES"i
|
||||||
SHOW: "SHOW"i
|
SHOW: "SHOW"i
|
||||||
|
CREATE: "CREATE"i
|
||||||
SERVICE: "SERVICE"i
|
SERVICE: "SERVICE"i
|
||||||
SHUTDOWN: "SHUTDOWN"i
|
SHUTDOWN: "SHUTDOWN"i
|
||||||
STARTUP: "STARTUP"i
|
STARTUP: "STARTUP"i
|
||||||
@ -43,6 +48,7 @@ USERS: "USERS"i
|
|||||||
DROP: "DROP"i
|
DROP: "DROP"i
|
||||||
USER: "USER"i
|
USER: "USER"i
|
||||||
ALTER: "ALTER"i
|
ALTER: "ALTER"i
|
||||||
|
ACTIVE: "ACTIVE"i
|
||||||
PASSWORD: "PASSWORD"i
|
PASSWORD: "PASSWORD"i
|
||||||
DATASETS: "DATASETS"i
|
DATASETS: "DATASETS"i
|
||||||
OF: "OF"i
|
OF: "OF"i
|
||||||
@ -58,12 +64,15 @@ list_users: LIST USERS ";"
|
|||||||
drop_user: DROP USER quoted_string ";"
|
drop_user: DROP USER quoted_string ";"
|
||||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||||
show_user: SHOW USER quoted_string ";"
|
show_user: SHOW USER quoted_string ";"
|
||||||
|
create_user: CREATE USER quoted_string quoted_string ";"
|
||||||
|
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||||
|
|
||||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||||
list_agents: LIST AGENTS OF quoted_string ";"
|
list_agents: LIST AGENTS OF quoted_string ";"
|
||||||
|
|
||||||
identifier: WORD
|
identifier: WORD
|
||||||
quoted_string: QUOTED_STRING
|
quoted_string: QUOTED_STRING
|
||||||
|
status: WORD
|
||||||
|
|
||||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||||
@ -118,6 +127,16 @@ class AdminTransformer(Transformer):
|
|||||||
new_password = items[4]
|
new_password = items[4]
|
||||||
return {"type": "alter_user", "username": user_name, "password": new_password}
|
return {"type": "alter_user", "username": user_name, "password": new_password}
|
||||||
|
|
||||||
|
def create_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
password = items[3]
|
||||||
|
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
|
||||||
|
|
||||||
|
def activate_user(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
activate_status = items[4]
|
||||||
|
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
|
||||||
|
|
||||||
def list_datasets(self, items):
|
def list_datasets(self, items):
|
||||||
user_name = items[3]
|
user_name = items[3]
|
||||||
return {"type": "list_datasets", "username": user_name}
|
return {"type": "list_datasets", "username": user_name}
|
||||||
@ -152,6 +171,14 @@ def encode_to_base64(input_string):
|
|||||||
return base64_encoded.decode('utf-8')
|
return base64_encoded.decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(input_string):
|
||||||
|
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
||||||
|
pub_key = RSA.importKey(pub)
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||||
|
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
||||||
|
return base64.b64encode(cipher_text).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class AdminCommandParser:
|
class AdminCommandParser:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
||||||
@ -220,6 +247,9 @@ class AdminCLI:
|
|||||||
if not data:
|
if not data:
|
||||||
print("No data to print")
|
print("No data to print")
|
||||||
return
|
return
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# handle single row data
|
||||||
|
data = [data]
|
||||||
|
|
||||||
columns = list(data[0].keys())
|
columns = list(data[0].keys())
|
||||||
col_widths = {}
|
col_widths = {}
|
||||||
@ -335,6 +365,10 @@ class AdminCLI:
|
|||||||
self._handle_drop_user(command_dict)
|
self._handle_drop_user(command_dict)
|
||||||
case 'alter_user':
|
case 'alter_user':
|
||||||
self._handle_alter_user(command_dict)
|
self._handle_alter_user(command_dict)
|
||||||
|
case 'create_user':
|
||||||
|
self._handle_create_user(command_dict)
|
||||||
|
case 'activate_user':
|
||||||
|
self._handle_activate_user(command_dict)
|
||||||
case 'list_datasets':
|
case 'list_datasets':
|
||||||
self._handle_list_datasets(command_dict)
|
self._handle_list_datasets(command_dict)
|
||||||
case 'list_agents':
|
case 'list_agents':
|
||||||
@ -349,9 +383,8 @@ class AdminCLI:
|
|||||||
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
||||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
res_json = dict
|
res_json = response.json()
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
res_json = response.json()
|
|
||||||
self._print_table_simple(res_json['data'])
|
self._print_table_simple(res_json['data'])
|
||||||
else:
|
else:
|
||||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
@ -377,9 +410,8 @@ class AdminCLI:
|
|||||||
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
res_json = dict
|
res_json = response.json()
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
res_json = response.json()
|
|
||||||
self._print_table_simple(res_json['data'])
|
self._print_table_simple(res_json['data'])
|
||||||
else:
|
else:
|
||||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
@ -388,6 +420,13 @@ class AdminCLI:
|
|||||||
username_tree: Tree = command['username']
|
username_tree: Tree = command['username']
|
||||||
username: str = username_tree.children[0].strip("'\"")
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
print(f"Showing user: {username}")
|
print(f"Showing user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
def _handle_drop_user(self, command):
|
def _handle_drop_user(self, command):
|
||||||
username_tree: Tree = command['username']
|
username_tree: Tree = command['username']
|
||||||
@ -400,16 +439,73 @@ class AdminCLI:
|
|||||||
password_tree: Tree = command['password']
|
password_tree: Tree = command['password']
|
||||||
password: str = password_tree.children[0].strip("'\"")
|
password: str = password_tree.children[0].strip("'\"")
|
||||||
print(f"Alter user: {username}, password: {password}")
|
print(f"Alter user: {username}, password: {password}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
|
||||||
|
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'new_password': encrypt(password)})
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(res_json["message"])
|
||||||
|
else:
|
||||||
|
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_create_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
password_tree: Tree = command['password']
|
||||||
|
password: str = password_tree.children[0].strip("'\"")
|
||||||
|
role: str = command['role']
|
||||||
|
print(f"Create user: {username}, password: {password}, role: {role}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||||
|
json={'username': username, 'password': encrypt(password), 'role': role}
|
||||||
|
)
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_activate_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
activate_tree: Tree = command['activate_status']
|
||||||
|
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||||
|
if activate_status.lower() in ['on', 'off']:
|
||||||
|
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
|
||||||
|
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'activate_status': activate_status})
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(res_json["message"])
|
||||||
|
else:
|
||||||
|
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
else:
|
||||||
|
print(f"Unknown activate status: {activate_status}.")
|
||||||
|
|
||||||
def _handle_list_datasets(self, command):
|
def _handle_list_datasets(self, command):
|
||||||
username_tree: Tree = command['username']
|
username_tree: Tree = command['username']
|
||||||
username: str = username_tree.children[0].strip("'\"")
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
print(f"Listing all datasets of user: {username}")
|
print(f"Listing all datasets of user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
def _handle_list_agents(self, command):
|
def _handle_list_agents(self, command):
|
||||||
username_tree: Tree = command['username']
|
username_tree: Tree = command['username']
|
||||||
username: str = username_tree.children[0].strip("'\"")
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
print(f"Listing all agents of user: {username}")
|
print(f"Listing all agents of user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
def _handle_meta_command(self, command):
|
def _handle_meta_command(self, command):
|
||||||
meta_command = command['command']
|
meta_command = command['command']
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from flask import Flask
|
|||||||
from routes import admin_bp
|
from routes import admin_bp
|
||||||
from api.utils.log_utils import init_root_logger
|
from api.utils.log_utils import init_root_logger
|
||||||
from api.constants import SERVICE_CONF
|
from api.constants import SERVICE_CONF
|
||||||
|
from api import settings
|
||||||
from config import load_configurations, SERVICE_CONFIGS
|
from config import load_configurations, SERVICE_CONFIGS
|
||||||
|
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
@ -26,7 +27,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.register_blueprint(admin_bp)
|
app.register_blueprint(admin_bp)
|
||||||
|
settings.init_settings()
|
||||||
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from flask import Blueprint, request
|
from flask import Blueprint, request
|
||||||
|
|
||||||
from auth import login_verify
|
from auth import login_verify
|
||||||
from responses import success_response, error_response
|
from responses import success_response, error_response
|
||||||
from services import UserMgr, ServiceMgr
|
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||||
from exceptions import AdminException
|
from exceptions import AdminException
|
||||||
|
|
||||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||||
@ -38,13 +39,18 @@ def create_user():
|
|||||||
password = data['password']
|
password = data['password']
|
||||||
role = data.get('role', 'user')
|
role = data.get('role', 'user')
|
||||||
|
|
||||||
user = UserMgr.create_user(username, password, role)
|
res = UserMgr.create_user(username, password, role)
|
||||||
return success_response(user, "User created successfully", 201)
|
if res["success"]:
|
||||||
|
user_info = res["user_info"]
|
||||||
|
user_info.pop("password") # do not return password
|
||||||
|
return success_response(user_info, "User created successfully")
|
||||||
|
else:
|
||||||
|
return error_response("create user failed")
|
||||||
|
|
||||||
except AdminException as e:
|
except AdminException as e:
|
||||||
return error_response(e.message, e.code)
|
return error_response(e.message, e.code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e))
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||||
@ -69,8 +75,8 @@ def change_password(username):
|
|||||||
return error_response("New password is required", 400)
|
return error_response("New password is required", 400)
|
||||||
|
|
||||||
new_password = data['new_password']
|
new_password = data['new_password']
|
||||||
UserMgr.update_user_password(username, new_password)
|
msg = UserMgr.update_user_password(username, new_password)
|
||||||
return success_response(None, "Password updated successfully")
|
return success_response(None, msg)
|
||||||
|
|
||||||
except AdminException as e:
|
except AdminException as e:
|
||||||
return error_response(e.message, e.code)
|
return error_response(e.message, e.code)
|
||||||
@ -78,6 +84,21 @@ def change_password(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||||
|
@login_verify
|
||||||
|
def alter_user_activate_status(username):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or 'activate_status' not in data:
|
||||||
|
return error_response("Activation status is required", 400)
|
||||||
|
activate_status = data['activate_status']
|
||||||
|
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||||
|
return success_response(None, msg)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||||
@login_verify
|
@login_verify
|
||||||
def get_user_details(username):
|
def get_user_details(username):
|
||||||
@ -90,6 +111,31 @@ def get_user_details(username):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_user_datasets(username):
|
||||||
|
try:
|
||||||
|
datasets_list = UserServiceMgr.get_user_datasets(username)
|
||||||
|
return success_response(datasets_list)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_user_agents(username):
|
||||||
|
try:
|
||||||
|
agents_list = UserServiceMgr.get_user_agents(username)
|
||||||
|
return success_response(agents_list)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/services', methods=['GET'])
|
@admin_bp.route('/services', methods=['GET'])
|
||||||
@login_verify
|
@login_verify
|
||||||
|
|||||||
@ -1,5 +1,13 @@
|
|||||||
|
import re
|
||||||
|
from werkzeug.security import check_password_hash
|
||||||
|
from api.db import ActiveEnum
|
||||||
from api.db.services import UserService
|
from api.db.services import UserService
|
||||||
from exceptions import AdminException
|
from api.db.joint_services.user_account_service import create_new_user
|
||||||
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
|
from api.db.services.user_service import TenantService
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.utils.crypt import decrypt
|
||||||
|
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
|
||||||
from config import SERVICE_CONFIGS
|
from config import SERVICE_CONFIGS
|
||||||
|
|
||||||
class UserMgr:
|
class UserMgr:
|
||||||
@ -13,19 +21,120 @@ class UserMgr:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_user_details(username):
|
def get_user_details(username):
|
||||||
raise AdminException("get_user_details: not implemented")
|
# use email to query
|
||||||
|
users = UserService.query_user_by_email(username)
|
||||||
|
result = []
|
||||||
|
for user in users:
|
||||||
|
result.append({
|
||||||
|
'email': user.email,
|
||||||
|
'language': user.language,
|
||||||
|
'last_login_time': user.last_login_time,
|
||||||
|
'is_authenticated': user.is_authenticated,
|
||||||
|
'is_active': user.is_active,
|
||||||
|
'is_anonymous': user.is_anonymous,
|
||||||
|
'login_channel': user.login_channel,
|
||||||
|
'status': user.status,
|
||||||
|
'is_superuser': user.is_superuser,
|
||||||
|
'create_date': user.create_date,
|
||||||
|
'update_date': user.update_date
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_user(username, password, role="user"):
|
def create_user(username, password, role="user") -> dict:
|
||||||
raise AdminException("create_user: not implemented")
|
# Validate the email address
|
||||||
|
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
|
||||||
|
raise AdminException(f"Invalid email address: {username}!")
|
||||||
|
# Check if the email address is already used
|
||||||
|
if UserService.query(email=username):
|
||||||
|
raise UserAlreadyExistsError(username)
|
||||||
|
# Construct user info data
|
||||||
|
user_info_dict = {
|
||||||
|
"email": username,
|
||||||
|
"nickname": "", # ask user to edit it manually in settings.
|
||||||
|
"password": decrypt(password),
|
||||||
|
"login_channel": "password",
|
||||||
|
"is_superuser": role == "admin",
|
||||||
|
}
|
||||||
|
return create_new_user(user_info_dict)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_user(username):
|
def delete_user(username):
|
||||||
|
# use email to delete
|
||||||
raise AdminException("delete_user: not implemented")
|
raise AdminException("delete_user: not implemented")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_user_password(username, new_password):
|
def update_user_password(username, new_password) -> str:
|
||||||
raise AdminException("update_user_password: not implemented")
|
# use email to find user. check exist and unique.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# check new_password different from old.
|
||||||
|
usr = user_list[0]
|
||||||
|
psw = decrypt(new_password)
|
||||||
|
if check_password_hash(usr.password, psw):
|
||||||
|
return "Same password, no need to update!"
|
||||||
|
# update password
|
||||||
|
UserService.update_user_password(usr.id, psw)
|
||||||
|
return "Password updated successfully!"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_user_activate_status(username, activate_status: str):
|
||||||
|
# use email to find user. check exist and unique.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# check activate status different from new
|
||||||
|
usr = user_list[0]
|
||||||
|
# format activate_status before handle
|
||||||
|
_activate_status = activate_status.lower()
|
||||||
|
target_status = {
|
||||||
|
'on': ActiveEnum.ACTIVE.value,
|
||||||
|
'off': ActiveEnum.INACTIVE.value,
|
||||||
|
}.get(_activate_status)
|
||||||
|
if not target_status:
|
||||||
|
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||||
|
if target_status == usr.is_active:
|
||||||
|
return f"User activate status is already {_activate_status}!"
|
||||||
|
# update is_active
|
||||||
|
UserService.update_user(usr.id, {"is_active": target_status})
|
||||||
|
return f"Turn {_activate_status} user activate status successfully!"
|
||||||
|
|
||||||
|
class UserServiceMgr:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_datasets(username):
|
||||||
|
# use email to find user.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# find tenants
|
||||||
|
usr = user_list[0]
|
||||||
|
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||||
|
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||||
|
# filter permitted kb and owned kb
|
||||||
|
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_agents(username):
|
||||||
|
# use email to find user.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# find tenants
|
||||||
|
usr = user_list[0]
|
||||||
|
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||||
|
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||||
|
# filter permitted agents and owned agents
|
||||||
|
return UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||||
|
|
||||||
class ServiceMgr:
|
class ServiceMgr:
|
||||||
|
|
||||||
|
|||||||
@ -80,7 +80,7 @@ Here's description of each category:
|
|||||||
- Prioritize the most specific applicable category
|
- Prioritize the most specific applicable category
|
||||||
- Return only the category name without explanations
|
- Return only the category name without explanations
|
||||||
- Use "Other" only when no other category fits
|
- Use "Other" only when no other category fits
|
||||||
|
|
||||||
""".format(
|
""".format(
|
||||||
"\n - ".join(list(self.category_description.keys())),
|
"\n - ".join(list(self.category_description.keys())),
|
||||||
"\n".join(descriptions)
|
"\n".join(descriptions)
|
||||||
@ -96,7 +96,7 @@ Here's description of each category:
|
|||||||
class Categorize(LLM, ABC):
|
class Categorize(LLM, ABC):
|
||||||
component_name = "Categorize"
|
component_name = "Categorize"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
if not msg:
|
if not msg:
|
||||||
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
|
|||||||
|
|
||||||
user_prompt = """
|
user_prompt = """
|
||||||
---- Real Data ----
|
---- Real Data ----
|
||||||
{} →
|
{} →
|
||||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||||
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
|
|||||||
self.set_output("_next", cpn_ids)
|
self.set_output("_next", cpn_ids)
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
|
|||||||
"type": "line"
|
"type": "line"
|
||||||
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if self._param.method == "split":
|
if self._param.method == "split":
|
||||||
self._split(kwargs.get("line"))
|
self._split(kwargs.get("line"))
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
|
|||||||
class ArXiv(ToolBase, ABC):
|
class ArXiv(ToolBase, ABC):
|
||||||
component_name = "ArXiv"
|
component_name = "ArXiv"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -129,7 +129,7 @@ module.exports = { main };
|
|||||||
class CodeExec(ToolBase, ABC):
|
class CodeExec(ToolBase, ABC):
|
||||||
component_name = "CodeExec"
|
component_name = "CodeExec"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
lang = kwargs.get("lang", self._param.lang)
|
lang = kwargs.get("lang", self._param.lang)
|
||||||
script = kwargs.get("script", self._param.script)
|
script = kwargs.get("script", self._param.script)
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
|
|||||||
class DuckDuckGo(ToolBase, ABC):
|
class DuckDuckGo(ToolBase, ABC):
|
||||||
component_name = "DuckDuckGo"
|
component_name = "DuckDuckGo"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
|
|||||||
|
|
||||||
class Email(ToolBase, ABC):
|
class Email(ToolBase, ABC):
|
||||||
component_name = "Email"
|
component_name = "Email"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("to_email"):
|
if not kwargs.get("to_email"):
|
||||||
self.set_output("success", False)
|
self.set_output("success", False)
|
||||||
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
|
|||||||
To: {}
|
To: {}
|
||||||
Subject: {}
|
Subject: {}
|
||||||
Your email is on its way—sit tight!
|
Your email is on its way—sit tight!
|
||||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||||
|
|||||||
@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
|
|||||||
class ExeSQL(ToolBase, ABC):
|
class ExeSQL(ToolBase, ABC):
|
||||||
component_name = "ExeSQL"
|
component_name = "ExeSQL"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
|
|
||||||
def convert_decimals(obj):
|
def convert_decimals(obj):
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
|
|||||||
class GitHub(ToolBase, ABC):
|
class GitHub(ToolBase, ABC):
|
||||||
component_name = "GitHub"
|
component_name = "GitHub"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
|
|||||||
class Google(ToolBase, ABC):
|
class Google(ToolBase, ABC):
|
||||||
component_name = "Google"
|
component_name = "Google"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("q"):
|
if not kwargs.get("q"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
|
|||||||
class GoogleScholar(ToolBase, ABC):
|
class GoogleScholar(ToolBase, ABC):
|
||||||
component_name = "GoogleScholar"
|
component_name = "GoogleScholar"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
|
|||||||
self.meta:ToolMeta = {
|
self.meta:ToolMeta = {
|
||||||
"name": "pubmed_search",
|
"name": "pubmed_search",
|
||||||
"description": """
|
"description": """
|
||||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||||
In addition to MEDLINE, PubMed provides access to:
|
In addition to MEDLINE, PubMed provides access to:
|
||||||
- older references from the print version of Index Medicus, back to 1951 and earlier
|
- older references from the print version of Index Medicus, back to 1951 and earlier
|
||||||
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
||||||
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
|
|||||||
class PubMed(ToolBase, ABC):
|
class PubMed(ToolBase, ABC):
|
||||||
component_name = "PubMed"
|
component_name = "PubMed"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -74,7 +74,7 @@ class RetrievalParam(ToolParamBase):
|
|||||||
class Retrieval(ToolBase, ABC):
|
class Retrieval(ToolBase, ABC):
|
||||||
component_name = "Retrieval"
|
component_name = "Retrieval"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", self._param.empty_response)
|
self.set_output("formalized_content", self._param.empty_response)
|
||||||
@ -164,18 +164,18 @@ class Retrieval(ToolBase, ABC):
|
|||||||
|
|
||||||
# Format the chunks for JSON output (similar to how other tools do it)
|
# Format the chunks for JSON output (similar to how other tools do it)
|
||||||
json_output = kbinfos["chunks"].copy()
|
json_output = kbinfos["chunks"].copy()
|
||||||
|
|
||||||
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||||
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
||||||
|
|
||||||
# Set both formalized content and JSON output
|
# Set both formalized content and JSON output
|
||||||
self.set_output("formalized_content", form_cnt)
|
self.set_output("formalized_content", form_cnt)
|
||||||
self.set_output("json", json_output)
|
self.set_output("json", json_output)
|
||||||
|
|
||||||
return form_cnt
|
return form_cnt
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
|
|||||||
class SearXNG(ToolBase, ABC):
|
class SearXNG(ToolBase, ABC):
|
||||||
component_name = "SearXNG"
|
component_name = "SearXNG"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
# Gracefully handle try-run without inputs
|
# Gracefully handle try-run without inputs
|
||||||
query = kwargs.get("query")
|
query = kwargs.get("query")
|
||||||
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
|
|||||||
last_e = ""
|
last_e = ""
|
||||||
for _ in range(self._param.max_retries+1):
|
for _ in range(self._param.max_retries+1):
|
||||||
try:
|
try:
|
||||||
# 构建搜索参数
|
|
||||||
search_params = {
|
search_params = {
|
||||||
'q': query,
|
'q': query,
|
||||||
'format': 'json',
|
'format': 'json',
|
||||||
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
|
|||||||
'pageno': 1
|
'pageno': 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 发送搜索请求
|
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{searxng_url}/search",
|
f"{searxng_url}/search",
|
||||||
params=search_params,
|
params=search_params,
|
||||||
timeout=10
|
timeout=10
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
# 验证响应数据
|
|
||||||
if not data or not isinstance(data, dict):
|
if not data or not isinstance(data, dict):
|
||||||
raise ValueError("Invalid response from SearXNG")
|
raise ValueError("Invalid response from SearXNG")
|
||||||
|
|
||||||
results = data.get("results", [])
|
results = data.get("results", [])
|
||||||
if not isinstance(results, list):
|
if not isinstance(results, list):
|
||||||
raise ValueError("Invalid results format from SearXNG")
|
raise ValueError("Invalid results format from SearXNG")
|
||||||
|
|
||||||
# 限制结果数量
|
|
||||||
results = results[:self._param.top_n]
|
results = results[:self._param.top_n]
|
||||||
|
|
||||||
# 处理搜索结果
|
|
||||||
self._retrieve_chunks(results,
|
self._retrieve_chunks(results,
|
||||||
get_title=lambda r: r.get("title", ""),
|
get_title=lambda r: r.get("title", ""),
|
||||||
get_url=lambda r: r.get("url", ""),
|
get_url=lambda r: r.get("url", ""),
|
||||||
get_content=lambda r: r.get("content", ""))
|
get_content=lambda r: r.get("content", ""))
|
||||||
|
|
||||||
self.set_output("json", results)
|
self.set_output("json", results)
|
||||||
return self.output("formalized_content")
|
return self.output("formalized_content")
|
||||||
|
|
||||||
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Searching with SearXNG for relevant results...
|
Searching with SearXNG for relevant results...
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
|
|||||||
self.meta:ToolMeta = {
|
self.meta:ToolMeta = {
|
||||||
"name": "tavily_search",
|
"name": "tavily_search",
|
||||||
"description": """
|
"description": """
|
||||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||||
When searching:
|
When searching:
|
||||||
- Start with specific query which should focus on just a single aspect.
|
- Start with specific query which should focus on just a single aspect.
|
||||||
- Number of keywords in query should be less than 5.
|
- Number of keywords in query should be less than 5.
|
||||||
@ -101,7 +101,7 @@ When searching:
|
|||||||
class TavilySearch(ToolBase, ABC):
|
class TavilySearch(ToolBase, ABC):
|
||||||
component_name = "TavilySearch"
|
component_name = "TavilySearch"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
|
|||||||
class TavilyExtract(ToolBase, ABC):
|
class TavilyExtract(ToolBase, ABC):
|
||||||
component_name = "TavilyExtract"
|
component_name = "TavilyExtract"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||||
last_e = None
|
last_e = None
|
||||||
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||||
|
|||||||
@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
|
|||||||
class WenCai(ToolBase, ABC):
|
class WenCai(ToolBase, ABC):
|
||||||
component_name = "WenCai"
|
component_name = "WenCai"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("report", "")
|
self.set_output("report", "")
|
||||||
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
|
|||||||
class Wikipedia(ToolBase, ABC):
|
class Wikipedia(ToolBase, ABC):
|
||||||
component_name = "Wikipedia"
|
component_name = "Wikipedia"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
|
|||||||
class YahooFinance(ToolBase, ABC):
|
class YahooFinance(ToolBase, ABC):
|
||||||
component_name = "YahooFinance"
|
component_name = "YahooFinance"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("stock_code"):
|
if not kwargs.get("stock_code"):
|
||||||
self.set_output("report", "")
|
self.set_output("report", "")
|
||||||
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
|||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.db_models import close_connection
|
from api.db.db_models import close_connection
|
||||||
from api.db.services import UserService
|
from api.db.services import UserService
|
||||||
from api.utils import CustomJSONEncoder, commands
|
from api.utils.json import CustomJSONEncoder
|
||||||
|
from api.utils import commands
|
||||||
|
|
||||||
from flask_mail import Mail
|
from flask_mail import Mail
|
||||||
from flask_session import Session
|
from flask_session import Session
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters, active_required
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import StatusEnum, FileSource
|
from api.db import StatusEnum, FileSource
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
@ -38,6 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL
|
|||||||
|
|
||||||
@manager.route('/create', methods=['post']) # noqa: F821
|
@manager.route('/create', methods=['post']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
|
@active_required
|
||||||
@validate_request("name")
|
@validate_request("name")
|
||||||
def create():
|
def create():
|
||||||
req = request.json
|
req = request.json
|
||||||
|
|||||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
|||||||
INVALID = "0"
|
INVALID = "0"
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveEnum(Enum):
|
||||||
|
ACTIVE = "1"
|
||||||
|
INACTIVE = "0"
|
||||||
|
|
||||||
|
|
||||||
class UserTenantRole(StrEnum):
|
class UserTenantRole(StrEnum):
|
||||||
OWNER = 'owner'
|
OWNER = 'owner'
|
||||||
ADMIN = 'admin'
|
ADMIN = 'admin'
|
||||||
|
|||||||
@ -26,12 +26,14 @@ from functools import wraps
|
|||||||
|
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||||
|
|
||||||
from api import settings, utils
|
from api import settings, utils
|
||||||
from api.db import ParserType, SerializedType
|
from api.db import ParserType, SerializedType
|
||||||
|
from api.utils.json import json_dumps, json_loads
|
||||||
|
from api.utils.configs import deserialize_b64, serialize_b64
|
||||||
|
|
||||||
|
|
||||||
def singleton(cls, *args, **kw):
|
def singleton(cls, *args, **kw):
|
||||||
@ -70,12 +72,12 @@ class JSONField(LongTextField):
|
|||||||
def db_value(self, value):
|
def db_value(self, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
value = self.default_value
|
value = self.default_value
|
||||||
return utils.json_dumps(value)
|
return json_dumps(value)
|
||||||
|
|
||||||
def python_value(self, value):
|
def python_value(self, value):
|
||||||
if not value:
|
if not value:
|
||||||
return self.default_value
|
return self.default_value
|
||||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
|
|
||||||
|
|
||||||
class ListField(JSONField):
|
class ListField(JSONField):
|
||||||
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
|
|||||||
|
|
||||||
def db_value(self, value):
|
def db_value(self, value):
|
||||||
if self._serialized_type == SerializedType.PICKLE:
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
return utils.serialize_b64(value, to_str=True)
|
return serialize_b64(value, to_str=True)
|
||||||
elif self._serialized_type == SerializedType.JSON:
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
return utils.json_dumps(value, with_type=True)
|
return json_dumps(value, with_type=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
def python_value(self, value):
|
def python_value(self, value):
|
||||||
if self._serialized_type == SerializedType.PICKLE:
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
return utils.deserialize_b64(value)
|
return deserialize_b64(value)
|
||||||
elif self._serialized_type == SerializedType.JSON:
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
if value is None:
|
if value is None:
|
||||||
return {}
|
return {}
|
||||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def execute_sql(self, sql, params=None, commit=True):
|
def execute_sql(self, sql, params=None, commit=True):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().execute_sql(sql, params, commit)
|
return super().execute_sql(sql, params, commit)
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
error_messages = ['', 'Lost connection']
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
logging.error(f"DB execution failure: {e}")
|
logging.error(f"DB execution failure: {e}")
|
||||||
raise
|
raise
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_connection_loss(self):
|
def _handle_connection_loss(self):
|
||||||
self.close_all()
|
# self.close_all()
|
||||||
self.connect()
|
# self.connect()
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.connect()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to reconnect: {e}")
|
||||||
|
time.sleep(0.1)
|
||||||
|
self.connect()
|
||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().begin()
|
return super().begin()
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
error_messages = ['', 'Lost connection']
|
||||||
|
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -299,7 +328,16 @@ class BaseDataBase:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
database_config = settings.DATABASE.copy()
|
database_config = settings.DATABASE.copy()
|
||||||
db_name = database_config.pop("name")
|
db_name = database_config.pop("name")
|
||||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
|
||||||
|
pool_config = {
|
||||||
|
'max_retries': 5,
|
||||||
|
'retry_delay': 1,
|
||||||
|
}
|
||||||
|
database_config.update(pool_config)
|
||||||
|
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||||
|
db_name, **database_config
|
||||||
|
)
|
||||||
|
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||||
logging.info("init database on cluster mode successfully")
|
logging.info("init database on cluster mode successfully")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
api/db/joint_services/__init__.py
Normal file
0
api/db/joint_services/__init__.py
Normal file
120
api/db/joint_services/user_account_service.py
Normal file
120
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from api import settings
|
||||||
|
from api.db import FileType, UserTenantRole
|
||||||
|
from api.db.db_models import TenantLLM
|
||||||
|
from api.db.services.llm_service import get_init_tenant_llm
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_user(user_info: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||||
|
:param user_info: {
|
||||||
|
"email": <example@example.com>,
|
||||||
|
"nickname": <str, "name">,
|
||||||
|
"password": <decrypted password>,
|
||||||
|
"login_channel": <enum, "password">,
|
||||||
|
"is_superuser": <bool, role == "admin">,
|
||||||
|
}
|
||||||
|
:return: {
|
||||||
|
"success": <bool>,
|
||||||
|
"user_info": <dict>, # if true, return user_info
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# generate user_id and access_token for user
|
||||||
|
user_id = uuid.uuid1().hex
|
||||||
|
user_info['id'] = user_id
|
||||||
|
user_info['access_token'] = uuid.uuid1().hex
|
||||||
|
# construct tenant info
|
||||||
|
tenant = {
|
||||||
|
"id": user_id,
|
||||||
|
"name": user_info["nickname"] + "‘s Kingdom",
|
||||||
|
"llm_id": settings.CHAT_MDL,
|
||||||
|
"embd_id": settings.EMBEDDING_MDL,
|
||||||
|
"asr_id": settings.ASR_MDL,
|
||||||
|
"parser_ids": settings.PARSERS,
|
||||||
|
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||||
|
"rerank_id": settings.RERANK_MDL,
|
||||||
|
}
|
||||||
|
usr_tenant = {
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"invited_by": user_id,
|
||||||
|
"role": UserTenantRole.OWNER,
|
||||||
|
}
|
||||||
|
# construct file folder info
|
||||||
|
file_id = uuid.uuid1().hex
|
||||||
|
file = {
|
||||||
|
"id": file_id,
|
||||||
|
"parent_id": file_id,
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"created_by": user_id,
|
||||||
|
"name": "/",
|
||||||
|
"type": FileType.FOLDER.value,
|
||||||
|
"size": 0,
|
||||||
|
"location": "",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
tenant_llm = get_init_tenant_llm(user_id)
|
||||||
|
|
||||||
|
if not UserService.save(**user_info):
|
||||||
|
return {"success": False}
|
||||||
|
|
||||||
|
TenantService.insert(**tenant)
|
||||||
|
UserTenantService.insert(**usr_tenant)
|
||||||
|
TenantLLMService.insert_many(tenant_llm)
|
||||||
|
FileService.insert(file)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"user_info": user_info,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as create_error:
|
||||||
|
logging.exception(create_error)
|
||||||
|
# rollback
|
||||||
|
try:
|
||||||
|
TenantService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
u = UserTenantService.query(tenant_id=user_id)
|
||||||
|
if u:
|
||||||
|
UserTenantService.delete_by_id(u[0].id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
FileService.delete_by_id(file["id"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
# delete user row finally
|
||||||
|
try:
|
||||||
|
UserService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
# reraise
|
||||||
|
raise create_error
|
||||||
@ -61,6 +61,36 @@ class UserCanvasService(CommonService):
|
|||||||
|
|
||||||
return list(agents.dicts())
|
return list(agents.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||||
|
# will get all permitted agents, be cautious
|
||||||
|
fields = [
|
||||||
|
cls.model.title,
|
||||||
|
cls.model.permission,
|
||||||
|
cls.model.canvas_type,
|
||||||
|
cls.model.canvas_category
|
||||||
|
]
|
||||||
|
# find team agents and owned agents
|
||||||
|
agents = cls.model.select(*fields).where(
|
||||||
|
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||||
|
cls.model.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# sort by create_time, asc
|
||||||
|
agents.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later
|
||||||
|
offset, limit = 0, 50
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
ag_batch = agents.offset(offset).limit(limit)
|
||||||
|
_temp = list(ag_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_tenant_id(cls, pid):
|
def get_by_tenant_id(cls, pid):
|
||||||
|
|||||||
@ -14,12 +14,24 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||||
import peewee
|
import peewee
|
||||||
|
from peewee import InterfaceError, OperationalError
|
||||||
|
|
||||||
from api.db.db_models import DB
|
from api.db.db_models import DB
|
||||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||||
|
|
||||||
|
def retry_db_operation(func):
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||||
|
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||||
|
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
class CommonService:
|
class CommonService:
|
||||||
"""Base service class that provides common database operations.
|
"""Base service class that provides common database operations.
|
||||||
@ -202,6 +214,7 @@ class CommonService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
|
@retry_db_operation
|
||||||
def update_by_id(cls, pid, data):
|
def update_by_id(cls, pid, data):
|
||||||
# Update a single record by ID
|
# Update a single record by ID
|
||||||
# Args:
|
# Args:
|
||||||
|
|||||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
|||||||
|
|
||||||
return list(kbs.dicts()), count
|
return list(kbs.dicts()), count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||||
|
# will get all permitted kb, be cautious.
|
||||||
|
fields = [
|
||||||
|
cls.model.name,
|
||||||
|
cls.model.language,
|
||||||
|
cls.model.permission,
|
||||||
|
cls.model.doc_num,
|
||||||
|
cls.model.token_num,
|
||||||
|
cls.model.chunk_num,
|
||||||
|
cls.model.status,
|
||||||
|
cls.model.create_date,
|
||||||
|
cls.model.update_date
|
||||||
|
]
|
||||||
|
# find team kb and owned kb
|
||||||
|
kbs = cls.model.select(*fields).where(
|
||||||
|
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||||
|
cls.model.tenant_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# sort by create_time asc
|
||||||
|
kbs.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later.
|
||||||
|
offset, limit = 0, 50
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
kb_batch = kbs.offset(offset).limit(limit)
|
||||||
|
_temp = list(kb_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_ids(cls, tenant_id):
|
def get_kb_ids(cls, tenant_id):
|
||||||
|
|||||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def query_user_by_email(cls, email):
|
||||||
|
users = cls.model.select().where((cls.model.email == email))
|
||||||
|
return list(users)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def save(cls, **kwargs):
|
def save(cls, **kwargs):
|
||||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
|||||||
cls.model.update(user_dict).where(
|
cls.model.update(user_dict).where(
|
||||||
cls.model.id == user_id).execute()
|
cls.model.id == user_id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def update_user_password(cls, user_id, new_password):
|
||||||
|
with DB.atomic():
|
||||||
|
update_dict = {
|
||||||
|
"password": generate_password_hash(str(new_password)),
|
||||||
|
"update_time": current_timestamp(),
|
||||||
|
"update_date": datetime_format(datetime.now())
|
||||||
|
}
|
||||||
|
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def is_admin(cls, user_id):
|
def is_admin(cls, user_id):
|
||||||
|
|||||||
@ -41,7 +41,7 @@ from api import utils
|
|||||||
from api.db.db_models import init_database_tables as init_web_db
|
from api.db.db_models import init_database_tables as init_web_db
|
||||||
from api.db.init_data import init_web_data
|
from api.db.init_data import init_web_data
|
||||||
from api.versions import get_ragflow_version
|
from api.versions import get_ragflow_version
|
||||||
from api.utils import show_configs
|
from api.utils.configs import show_configs
|
||||||
from rag.settings import print_rag_settings
|
from rag.settings import print_rag_settings
|
||||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||||
from rag.utils.redis_conn import RedisDistributedLock
|
from rag.utils.redis_conn import RedisDistributedLock
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import rag.utils.es_conn
|
|||||||
import rag.utils.infinity_conn
|
import rag.utils.infinity_conn
|
||||||
import rag.utils.opensearch_conn
|
import rag.utils.opensearch_conn
|
||||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||||
from api.utils import decrypt_database_config, get_base_config
|
from api.utils.configs import decrypt_database_config, get_base_config
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
|
|
||||||
|
|||||||
@ -16,182 +16,15 @@
|
|||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import requests
|
import requests
|
||||||
import logging
|
|
||||||
import copy
|
|
||||||
from enum import Enum, IntEnum
|
|
||||||
import importlib
|
import importlib
|
||||||
from filelock import FileLock
|
|
||||||
from api.constants import SERVICE_CONF
|
|
||||||
|
|
||||||
from . import file_utils
|
from .common import string_to_bytes
|
||||||
|
|
||||||
|
|
||||||
def conf_realpath(conf_name):
|
|
||||||
conf_path = f"conf/{conf_name}"
|
|
||||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
|
||||||
|
|
||||||
|
|
||||||
def read_config(conf_name=SERVICE_CONF):
|
|
||||||
local_config = {}
|
|
||||||
local_path = conf_realpath(f'local.{conf_name}')
|
|
||||||
|
|
||||||
# load local config file
|
|
||||||
if os.path.exists(local_path):
|
|
||||||
local_config = file_utils.load_yaml_conf(local_path)
|
|
||||||
if not isinstance(local_config, dict):
|
|
||||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
|
||||||
|
|
||||||
global_config_path = conf_realpath(conf_name)
|
|
||||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
|
||||||
|
|
||||||
if not isinstance(global_config, dict):
|
|
||||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
|
||||||
|
|
||||||
global_config.update(local_config)
|
|
||||||
return global_config
|
|
||||||
|
|
||||||
|
|
||||||
CONFIGS = read_config()
|
|
||||||
|
|
||||||
|
|
||||||
def show_configs():
|
|
||||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
|
||||||
for k, v in CONFIGS.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
if "password" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["password"] = "*" * 8
|
|
||||||
if "access_key" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["access_key"] = "*" * 8
|
|
||||||
if "secret_key" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["secret_key"] = "*" * 8
|
|
||||||
if "secret" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["secret"] = "*" * 8
|
|
||||||
if "sas_token" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["sas_token"] = "*" * 8
|
|
||||||
if "oauth" in k:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
for key, val in v.items():
|
|
||||||
if "client_secret" in val:
|
|
||||||
val["client_secret"] = "*" * 8
|
|
||||||
if "authentication" in k:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
for key, val in v.items():
|
|
||||||
if "http_secret_key" in val:
|
|
||||||
val["http_secret_key"] = "*" * 8
|
|
||||||
msg += f"\n\t{k}: {v}"
|
|
||||||
logging.info(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_config(key, default=None):
|
|
||||||
if key is None:
|
|
||||||
return None
|
|
||||||
if default is None:
|
|
||||||
default = os.environ.get(key.upper())
|
|
||||||
return CONFIGS.get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
use_deserialize_safe_module = get_base_config(
|
|
||||||
'use_deserialize_safe_module', False)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseType:
|
|
||||||
def to_dict(self):
|
|
||||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
|
||||||
|
|
||||||
def to_dict_with_type(self):
|
|
||||||
def _dict(obj):
|
|
||||||
module = None
|
|
||||||
if issubclass(obj.__class__, BaseType):
|
|
||||||
data = {}
|
|
||||||
for attr, v in obj.__dict__.items():
|
|
||||||
k = attr.lstrip("_")
|
|
||||||
data[k] = _dict(v)
|
|
||||||
module = obj.__module__
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
data = []
|
|
||||||
for i, vv in enumerate(obj):
|
|
||||||
data.append(_dict(vv))
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
data = {}
|
|
||||||
for _k, vv in obj.items():
|
|
||||||
data[_k] = _dict(vv)
|
|
||||||
else:
|
|
||||||
data = obj
|
|
||||||
return {"type": obj.__class__.__name__,
|
|
||||||
"data": data, "module": module}
|
|
||||||
|
|
||||||
return _dict(self)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomJSONEncoder(json.JSONEncoder):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self._with_type = kwargs.pop("with_type", False)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime.datetime):
|
|
||||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
elif isinstance(obj, datetime.date):
|
|
||||||
return obj.strftime('%Y-%m-%d')
|
|
||||||
elif isinstance(obj, datetime.timedelta):
|
|
||||||
return str(obj)
|
|
||||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
|
||||||
return obj.value
|
|
||||||
elif isinstance(obj, set):
|
|
||||||
return list(obj)
|
|
||||||
elif issubclass(type(obj), BaseType):
|
|
||||||
if not self._with_type:
|
|
||||||
return obj.to_dict()
|
|
||||||
else:
|
|
||||||
return obj.to_dict_with_type()
|
|
||||||
elif isinstance(obj, type):
|
|
||||||
return obj.__name__
|
|
||||||
else:
|
|
||||||
return json.JSONEncoder.default(self, obj)
|
|
||||||
|
|
||||||
|
|
||||||
def rag_uuid():
|
|
||||||
return uuid.uuid1().hex
|
|
||||||
|
|
||||||
|
|
||||||
def string_to_bytes(string):
|
|
||||||
return string if isinstance(
|
|
||||||
string, bytes) else string.encode(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_string(byte):
|
|
||||||
return byte.decode(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
|
||||||
dest = json.dumps(
|
|
||||||
src,
|
|
||||||
indent=indent,
|
|
||||||
cls=CustomJSONEncoder,
|
|
||||||
with_type=with_type)
|
|
||||||
if byte:
|
|
||||||
dest = string_to_bytes(dest)
|
|
||||||
return dest
|
|
||||||
|
|
||||||
|
|
||||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
|
||||||
if isinstance(src, bytes):
|
|
||||||
src = bytes_to_string(src)
|
|
||||||
return json.loads(src, object_hook=object_hook,
|
|
||||||
object_pairs_hook=object_pairs_hook)
|
|
||||||
|
|
||||||
|
|
||||||
def current_timestamp():
|
def current_timestamp():
|
||||||
@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
|||||||
return time_stamp
|
return time_stamp
|
||||||
|
|
||||||
|
|
||||||
def serialize_b64(src, to_str=False):
|
|
||||||
dest = base64.b64encode(pickle.dumps(src))
|
|
||||||
if not to_str:
|
|
||||||
return dest
|
|
||||||
else:
|
|
||||||
return bytes_to_string(dest)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_b64(src):
|
|
||||||
src = base64.b64decode(
|
|
||||||
string_to_bytes(src) if isinstance(
|
|
||||||
src, str) else src)
|
|
||||||
if use_deserialize_safe_module:
|
|
||||||
return restricted_loads(src)
|
|
||||||
return pickle.loads(src)
|
|
||||||
|
|
||||||
|
|
||||||
safe_module = {
|
|
||||||
'numpy',
|
|
||||||
'rag_flow'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
import importlib
|
|
||||||
if module.split('.')[0] in safe_module:
|
|
||||||
_module = importlib.import_module(module)
|
|
||||||
return getattr(_module, name)
|
|
||||||
# Forbid everything else.
|
|
||||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
||||||
(module, name))
|
|
||||||
|
|
||||||
|
|
||||||
def restricted_loads(src):
|
|
||||||
"""Helper function analogous to pickle.loads()."""
|
|
||||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
|
||||||
|
|
||||||
|
|
||||||
def get_lan_ip():
|
def get_lan_ip():
|
||||||
if os.name != "nt":
|
if os.name != "nt":
|
||||||
import fcntl
|
import fcntl
|
||||||
@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict):
|
|||||||
return in_dict
|
return in_dict
|
||||||
|
|
||||||
|
|
||||||
def decrypt_database_password(password):
|
|
||||||
encrypt_password = get_base_config("encrypt_password", False)
|
|
||||||
encrypt_module = get_base_config("encrypt_module", False)
|
|
||||||
private_key = get_base_config("private_key", None)
|
|
||||||
|
|
||||||
if not password or not encrypt_password:
|
|
||||||
return password
|
|
||||||
|
|
||||||
if not private_key:
|
|
||||||
raise ValueError("No private key")
|
|
||||||
|
|
||||||
module_fun = encrypt_module.split("#")
|
|
||||||
pwdecrypt_fun = getattr(
|
|
||||||
importlib.import_module(
|
|
||||||
module_fun[0]),
|
|
||||||
module_fun[1])
|
|
||||||
|
|
||||||
return pwdecrypt_fun(private_key, password)
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_database_config(
|
|
||||||
database=None, passwd_key="password", name="database"):
|
|
||||||
if not database:
|
|
||||||
database = get_base_config(name, {})
|
|
||||||
|
|
||||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
|
||||||
return database
|
|
||||||
|
|
||||||
|
|
||||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
|
||||||
conf_path = conf_realpath(conf_name=conf_name)
|
|
||||||
if not os.path.isabs(conf_path):
|
|
||||||
conf_path = os.path.join(
|
|
||||||
file_utils.get_project_base_directory(), conf_path)
|
|
||||||
|
|
||||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
|
||||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
|
||||||
config[key] = value
|
|
||||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_uuid():
|
def get_uuid():
|
||||||
return uuid.uuid1().hex
|
return uuid.uuid1().hex
|
||||||
|
|
||||||
@ -375,5 +128,5 @@ def delta_seconds(date_string: str):
|
|||||||
return (datetime.datetime.now() - dt).total_seconds()
|
return (datetime.datetime.now() - dt).total_seconds()
|
||||||
|
|
||||||
|
|
||||||
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
|
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||||
|
|||||||
@ -39,6 +39,7 @@ from flask import (
|
|||||||
make_response,
|
make_response,
|
||||||
send_file,
|
send_file,
|
||||||
)
|
)
|
||||||
|
from flask_login import current_user
|
||||||
from flask import (
|
from flask import (
|
||||||
request as flask_request,
|
request as flask_request,
|
||||||
)
|
)
|
||||||
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
|
|||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||||
|
from api.db import ActiveEnum
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
|
from api.db.services import UserService
|
||||||
from api.db.services.llm_service import LLMService
|
from api.db.services.llm_service import LLMService
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||||
|
from api.utils import get_uuid
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||||
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def active_required(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
user_id = current_user.id
|
||||||
|
usr = UserService.filter_by_id(user_id)
|
||||||
|
# check is_active
|
||||||
|
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||||
|
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def is_localhost(ip):
|
def is_localhost(ip):
|
||||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||||
|
|
||||||
|
|||||||
23
api/utils/common.py
Normal file
23
api/utils/common.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
def string_to_bytes(string):
|
||||||
|
return string if isinstance(
|
||||||
|
string, bytes) else string.encode(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_string(byte):
|
||||||
|
return byte.decode(encoding="utf-8")
|
||||||
179
api/utils/configs.py
Normal file
179
api/utils/configs.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import base64
|
||||||
|
import pickle
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from api.utils import file_utils
|
||||||
|
from filelock import FileLock
|
||||||
|
from api.utils.common import bytes_to_string, string_to_bytes
|
||||||
|
from api.constants import SERVICE_CONF
|
||||||
|
|
||||||
|
|
||||||
|
def conf_realpath(conf_name):
|
||||||
|
conf_path = f"conf/{conf_name}"
|
||||||
|
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||||
|
|
||||||
|
|
||||||
|
def read_config(conf_name=SERVICE_CONF):
|
||||||
|
local_config = {}
|
||||||
|
local_path = conf_realpath(f'local.{conf_name}')
|
||||||
|
|
||||||
|
# load local config file
|
||||||
|
if os.path.exists(local_path):
|
||||||
|
local_config = file_utils.load_yaml_conf(local_path)
|
||||||
|
if not isinstance(local_config, dict):
|
||||||
|
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||||
|
|
||||||
|
global_config_path = conf_realpath(conf_name)
|
||||||
|
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||||
|
|
||||||
|
if not isinstance(global_config, dict):
|
||||||
|
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||||
|
|
||||||
|
global_config.update(local_config)
|
||||||
|
return global_config
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGS = read_config()
|
||||||
|
|
||||||
|
|
||||||
|
def show_configs():
|
||||||
|
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||||
|
for k, v in CONFIGS.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
if "password" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["password"] = "*" * 8
|
||||||
|
if "access_key" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["access_key"] = "*" * 8
|
||||||
|
if "secret_key" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["secret_key"] = "*" * 8
|
||||||
|
if "secret" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["secret"] = "*" * 8
|
||||||
|
if "sas_token" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["sas_token"] = "*" * 8
|
||||||
|
if "oauth" in k:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
for key, val in v.items():
|
||||||
|
if "client_secret" in val:
|
||||||
|
val["client_secret"] = "*" * 8
|
||||||
|
if "authentication" in k:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
for key, val in v.items():
|
||||||
|
if "http_secret_key" in val:
|
||||||
|
val["http_secret_key"] = "*" * 8
|
||||||
|
msg += f"\n\t{k}: {v}"
|
||||||
|
logging.info(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_config(key, default=None):
|
||||||
|
if key is None:
|
||||||
|
return None
|
||||||
|
if default is None:
|
||||||
|
default = os.environ.get(key.upper())
|
||||||
|
return CONFIGS.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_database_password(password):
|
||||||
|
encrypt_password = get_base_config("encrypt_password", False)
|
||||||
|
encrypt_module = get_base_config("encrypt_module", False)
|
||||||
|
private_key = get_base_config("private_key", None)
|
||||||
|
|
||||||
|
if not password or not encrypt_password:
|
||||||
|
return password
|
||||||
|
|
||||||
|
if not private_key:
|
||||||
|
raise ValueError("No private key")
|
||||||
|
|
||||||
|
module_fun = encrypt_module.split("#")
|
||||||
|
pwdecrypt_fun = getattr(
|
||||||
|
importlib.import_module(
|
||||||
|
module_fun[0]),
|
||||||
|
module_fun[1])
|
||||||
|
|
||||||
|
return pwdecrypt_fun(private_key, password)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_database_config(
|
||||||
|
database=None, passwd_key="password", name="database"):
|
||||||
|
if not database:
|
||||||
|
database = get_base_config(name, {})
|
||||||
|
|
||||||
|
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||||
|
return database
|
||||||
|
|
||||||
|
|
||||||
|
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||||
|
conf_path = conf_realpath(conf_name=conf_name)
|
||||||
|
if not os.path.isabs(conf_path):
|
||||||
|
conf_path = os.path.join(
|
||||||
|
file_utils.get_project_base_directory(), conf_path)
|
||||||
|
|
||||||
|
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||||
|
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||||
|
config[key] = value
|
||||||
|
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||||
|
|
||||||
|
|
||||||
|
safe_module = {
|
||||||
|
'numpy',
|
||||||
|
'rag_flow'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
|
def find_class(self, module, name):
|
||||||
|
import importlib
|
||||||
|
if module.split('.')[0] in safe_module:
|
||||||
|
_module = importlib.import_module(module)
|
||||||
|
return getattr(_module, name)
|
||||||
|
# Forbid everything else.
|
||||||
|
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||||
|
(module, name))
|
||||||
|
|
||||||
|
|
||||||
|
def restricted_loads(src):
|
||||||
|
"""Helper function analogous to pickle.loads()."""
|
||||||
|
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_b64(src, to_str=False):
|
||||||
|
dest = base64.b64encode(pickle.dumps(src))
|
||||||
|
if not to_str:
|
||||||
|
return dest
|
||||||
|
else:
|
||||||
|
return bytes_to_string(dest)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_b64(src):
|
||||||
|
src = base64.b64decode(
|
||||||
|
string_to_bytes(src) if isinstance(
|
||||||
|
src, str) else src)
|
||||||
|
use_deserialize_safe_module = get_base_config(
|
||||||
|
'use_deserialize_safe_module', False)
|
||||||
|
if use_deserialize_safe_module:
|
||||||
|
return restricted_loads(src)
|
||||||
|
return pickle.loads(src)
|
||||||
@ -23,6 +23,9 @@ from api.utils import file_utils
|
|||||||
|
|
||||||
|
|
||||||
def crypt(line):
|
def crypt(line):
|
||||||
|
"""
|
||||||
|
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||||
|
"""
|
||||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
|
|||||||
78
api/utils/json.py
Normal file
78
api/utils/json.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
from enum import Enum, IntEnum
|
||||||
|
from api.utils.common import string_to_bytes, bytes_to_string
|
||||||
|
|
||||||
|
|
||||||
|
class BaseType:
|
||||||
|
def to_dict(self):
|
||||||
|
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||||
|
|
||||||
|
def to_dict_with_type(self):
|
||||||
|
def _dict(obj):
|
||||||
|
module = None
|
||||||
|
if issubclass(obj.__class__, BaseType):
|
||||||
|
data = {}
|
||||||
|
for attr, v in obj.__dict__.items():
|
||||||
|
k = attr.lstrip("_")
|
||||||
|
data[k] = _dict(v)
|
||||||
|
module = obj.__module__
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
data = []
|
||||||
|
for i, vv in enumerate(obj):
|
||||||
|
data.append(_dict(vv))
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
data = {}
|
||||||
|
for _k, vv in obj.items():
|
||||||
|
data[_k] = _dict(vv)
|
||||||
|
else:
|
||||||
|
data = obj
|
||||||
|
return {"type": obj.__class__.__name__,
|
||||||
|
"data": data, "module": module}
|
||||||
|
|
||||||
|
return _dict(self)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomJSONEncoder(json.JSONEncoder):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._with_type = kwargs.pop("with_type", False)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, datetime.datetime):
|
||||||
|
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
elif isinstance(obj, datetime.date):
|
||||||
|
return obj.strftime('%Y-%m-%d')
|
||||||
|
elif isinstance(obj, datetime.timedelta):
|
||||||
|
return str(obj)
|
||||||
|
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||||
|
return obj.value
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
return list(obj)
|
||||||
|
elif issubclass(type(obj), BaseType):
|
||||||
|
if not self._with_type:
|
||||||
|
return obj.to_dict()
|
||||||
|
else:
|
||||||
|
return obj.to_dict_with_type()
|
||||||
|
elif isinstance(obj, type):
|
||||||
|
return obj.__name__
|
||||||
|
else:
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
||||||
|
|
||||||
|
|
||||||
|
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||||
|
dest = json.dumps(
|
||||||
|
src,
|
||||||
|
indent=indent,
|
||||||
|
cls=CustomJSONEncoder,
|
||||||
|
with_type=with_type)
|
||||||
|
if byte:
|
||||||
|
dest = string_to_bytes(dest)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||||
|
if isinstance(src, bytes):
|
||||||
|
src = bytes_to_string(src)
|
||||||
|
return json.loads(src, object_hook=object_hook,
|
||||||
|
object_pairs_hook=object_pairs_hook)
|
||||||
@ -350,7 +350,7 @@ class TextRecognizer:
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
# close session and release manually
|
# close session and release manually
|
||||||
logging.info('Close TextRecognizer.')
|
logging.info('Close text recognizer.')
|
||||||
if hasattr(self, "predictor"):
|
if hasattr(self, "predictor"):
|
||||||
del self.predictor
|
del self.predictor
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -490,7 +490,7 @@ class TextDetector:
|
|||||||
return dt_boxes
|
return dt_boxes
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
logging.info("Close TextDetector.")
|
logging.info("Close text detector.")
|
||||||
if hasattr(self, "predictor"):
|
if hasattr(self, "predictor"):
|
||||||
del self.predictor
|
del self.predictor
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class Base(ABC):
|
|||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
if self.model_name.lower().find("qwen3") >= 0:
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
kwargs["extra_body"] = {"enable_thinking": False}
|
||||||
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||||
|
|
||||||
if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
|
if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
|
||||||
@ -156,12 +156,12 @@ class Base(ABC):
|
|||||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||||
reasoning_start = False
|
reasoning_start = False
|
||||||
|
|
||||||
if kwargs.get("stop") or "stop" in gen_conf:
|
if kwargs.get("stop") or "stop" in gen_conf:
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
||||||
else:
|
else:
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||||
|
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if not resp.choices:
|
if not resp.choices:
|
||||||
continue
|
continue
|
||||||
@ -643,7 +643,7 @@ class ZhipuChat(Base):
|
|||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
gen_conf = self._clean_conf_plealty(gen_conf)
|
gen_conf = self._clean_conf_plealty(gen_conf)
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
def _clean_conf_plealty(self, gen_conf):
|
def _clean_conf_plealty(self, gen_conf):
|
||||||
if "presence_penalty" in gen_conf:
|
if "presence_penalty" in gen_conf:
|
||||||
del gen_conf["presence_penalty"]
|
del gen_conf["presence_penalty"]
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class FulltextQueryer:
|
|||||||
def rmWWW(txt):
|
def rmWWW(txt):
|
||||||
patts = [
|
patts = [
|
||||||
(
|
(
|
||||||
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||||
"",
|
"",
|
||||||
),
|
),
|
||||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from api.utils import get_base_config, decrypt_database_config
|
from api.utils.configs import get_base_config, decrypt_database_config
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
# Server
|
# Server
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import logging
|
|||||||
import pymysql
|
import pymysql
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
from api.utils import get_base_config
|
from api.utils.configs import get_base_config
|
||||||
from rag.utils import singleton
|
from rag.utils import singleton
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -949,7 +949,8 @@ export default {
|
|||||||
multimodalModels: 'Мультимодальные модели',
|
multimodalModels: 'Мультимодальные модели',
|
||||||
textOnlyModels: 'Только текстовые модели',
|
textOnlyModels: 'Только текстовые модели',
|
||||||
allModels: 'Все модели',
|
allModels: 'Все модели',
|
||||||
codeExecDescription: 'Напишите свою пользовательскую логику на Python или Javascript.',
|
codeExecDescription:
|
||||||
|
'Напишите свою пользовательскую логику на Python или Javascript.',
|
||||||
stringTransformDescription:
|
stringTransformDescription:
|
||||||
'Изменяет текстовое содержимое. В настоящее время поддерживает: разделение или объединение текста.',
|
'Изменяет текстовое содержимое. В настоящее время поддерживает: разделение или объединение текста.',
|
||||||
foundation: 'Основа',
|
foundation: 'Основа',
|
||||||
|
|||||||
Reference in New Issue
Block a user