Compare commits

...

5 Commits

Author SHA1 Message Date
b0b866c8fd Refactor: move some functions out of api/utils/__init__.py (#10216)
### What problem does this PR solve?

Refactor import modules.

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-09-25 18:04:49 +08:00
3a831d0c28 Fixed the issue where database connections were interrupted under high concurrency (#10126)
### What problem does this PR solve?

Fixed the issue where database connections were interrupted under high
concurrency

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: lemsn <lemsn@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-09-25 17:03:43 +08:00
9e323a9351 Feat(nlp): add "怎么办" pattern to question word removal (#10284)
### What problem does this PR solve?

Added "怎么办" to the regex pattern in rmWWW method to improve query
cleaning by removing this common question phrase along with other
question words.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-09-25 16:47:56 +08:00
7ac95b759b Feat/admin service (#10233)
### What problem does this PR solve?

- Admin client support show user and create user command.
- Admin client support alter user password and active status.
- Admin client support list user datasets.

issue: #10241

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-09-25 16:15:15 +08:00
daea357940 Fix: invalid COMPONENT_EXEC_TIMEOUT (#10278)
### What problem does this PR solve?

Fix invalid COMPONENT_EXEC_TIMEOUT. #10273

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-09-25 14:11:09 +08:00
45 changed files with 923 additions and 363 deletions

View File

@ -1,5 +1,7 @@
import argparse
import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from typing import Dict, List, Any
from lark import Lark, Transformer, Tree
import requests
@ -19,6 +21,8 @@ sql_command: list_services
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
@ -35,6 +39,7 @@ meta_arg: /[^\\s"']+/ | quoted_string
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
@ -43,6 +48,7 @@ USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
PASSWORD: "PASSWORD"i
DATASETS: "DATASETS"i
OF: "OF"i
@ -58,12 +64,15 @@ list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
@ -118,6 +127,16 @@ class AdminTransformer(Transformer):
new_password = items[4]
return {"type": "alter_user", "username": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "username": user_name}
@ -152,6 +171,14 @@ def encode_to_base64(input_string):
return base64_encoded.decode('utf-8')
def encrypt(input_string):
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
return base64.b64encode(cipher_text).decode("utf-8")
class AdminCommandParser:
def __init__(self):
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
@ -220,6 +247,9 @@ class AdminCLI:
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(data[0].keys())
col_widths = {}
@ -335,6 +365,10 @@ class AdminCLI:
self._handle_drop_user(command_dict)
case 'alter_user':
self._handle_alter_user(command_dict)
case 'create_user':
self._handle_create_user(command_dict)
case 'activate_user':
self._handle_activate_user(command_dict)
case 'list_datasets':
self._handle_list_datasets(command_dict)
case 'list_agents':
@ -349,9 +383,8 @@ class AdminCLI:
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = dict
res_json = response.json()
if response.status_code == 200:
res_json = response.json()
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
@ -377,9 +410,8 @@ class AdminCLI:
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = dict
res_json = response.json()
if response.status_code == 200:
res_json = response.json()
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
@ -388,6 +420,13 @@ class AdminCLI:
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Showing user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_drop_user(self, command):
username_tree: Tree = command['username']
@ -400,16 +439,73 @@ class AdminCLI:
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
print(f"Alter user: {username}, password: {password}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'new_password': encrypt(password)})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
def _handle_create_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
role: str = command['role']
print(f"Create user: {username}, password: {password}, role: {role}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = requests.post(
url,
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
json={'username': username, 'password': encrypt(password), 'role': role}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_activate_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
activate_tree: Tree = command['activate_status']
activate_status: str = activate_tree.children[0].strip("'\"")
if activate_status.lower() in ['on', 'off']:
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'activate_status': activate_status})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
else:
print(f"Unknown activate status: {activate_status}.")
def _handle_list_datasets(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Listing all datasets of user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_list_agents(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Listing all agents of user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_meta_command(self, command):
meta_command = command['command']

View File

@ -10,6 +10,7 @@ from flask import Flask
from routes import admin_bp
from api.utils.log_utils import init_root_logger
from api.constants import SERVICE_CONF
from api import settings
from config import load_configurations, SERVICE_CONFIGS
stop_event = threading.Event()
@ -26,7 +27,7 @@ if __name__ == '__main__':
app = Flask(__name__)
app.register_blueprint(admin_bp)
settings.init_settings()
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
try:

View File

@ -1,7 +1,8 @@
from flask import Blueprint, request
from auth import login_verify
from responses import success_response, error_response
from services import UserMgr, ServiceMgr
from services import UserMgr, ServiceMgr, UserServiceMgr
from exceptions import AdminException
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@ -38,13 +39,18 @@ def create_user():
password = data['password']
role = data.get('role', 'user')
user = UserMgr.create_user(username, password, role)
return success_response(user, "User created successfully", 201)
res = UserMgr.create_user(username, password, role)
if res["success"]:
user_info = res["user_info"]
user_info.pop("password") # do not return password
return success_response(user_info, "User created successfully")
else:
return error_response("create user failed")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
return error_response(str(e))
@admin_bp.route('/users/<username>', methods=['DELETE'])
@ -69,8 +75,8 @@ def change_password(username):
return error_response("New password is required", 400)
new_password = data['new_password']
UserMgr.update_user_password(username, new_password)
return success_response(None, "Password updated successfully")
msg = UserMgr.update_user_password(username, new_password)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
@ -78,6 +84,21 @@ def change_password(username):
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
@login_verify
def alter_user_activate_status(username):
try:
data = request.get_json()
if not data or 'activate_status' not in data:
return error_response("Activation status is required", 400)
activate_status = data['activate_status']
msg = UserMgr.update_user_activate_status(username, activate_status)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>', methods=['GET'])
@login_verify
def get_user_details(username):
@ -90,6 +111,31 @@ def get_user_details(username):
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
@login_verify
def get_user_datasets(username):
try:
datasets_list = UserServiceMgr.get_user_datasets(username)
return success_response(datasets_list)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/agents', methods=['GET'])
@login_verify
def get_user_agents(username):
try:
agents_list = UserServiceMgr.get_user_agents(username)
return success_response(agents_list)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/services', methods=['GET'])
@login_verify

View File

@ -1,5 +1,13 @@
import re
from werkzeug.security import check_password_hash
from api.db import ActiveEnum
from api.db.services import UserService
from exceptions import AdminException
from api.db.joint_services.user_account_service import create_new_user
from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_service import TenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.crypt import decrypt
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
from config import SERVICE_CONFIGS
class UserMgr:
@ -13,19 +21,120 @@ class UserMgr:
@staticmethod
def get_user_details(username):
raise AdminException("get_user_details: not implemented")
# use email to query
users = UserService.query_user_by_email(username)
result = []
for user in users:
result.append({
'email': user.email,
'language': user.language,
'last_login_time': user.last_login_time,
'is_authenticated': user.is_authenticated,
'is_active': user.is_active,
'is_anonymous': user.is_anonymous,
'login_channel': user.login_channel,
'status': user.status,
'is_superuser': user.is_superuser,
'create_date': user.create_date,
'update_date': user.update_date
})
return result
@staticmethod
def create_user(username, password, role="user"):
raise AdminException("create_user: not implemented")
def create_user(username, password, role="user") -> dict:
# Validate the email address
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
raise AdminException(f"Invalid email address: {username}!")
# Check if the email address is already used
if UserService.query(email=username):
raise UserAlreadyExistsError(username)
# Construct user info data
user_info_dict = {
"email": username,
"nickname": "", # ask user to edit it manually in settings.
"password": decrypt(password),
"login_channel": "password",
"is_superuser": role == "admin",
}
return create_new_user(user_info_dict)
@staticmethod
def delete_user(username):
# use email to delete
raise AdminException("delete_user: not implemented")
@staticmethod
def update_user_password(username, new_password):
raise AdminException("update_user_password: not implemented")
def update_user_password(username, new_password) -> str:
# use email to find user. check exist and unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# check new_password different from old.
usr = user_list[0]
psw = decrypt(new_password)
if check_password_hash(usr.password, psw):
return "Same password, no need to update!"
# update password
UserService.update_user_password(usr.id, psw)
return "Password updated successfully!"
@staticmethod
def update_user_activate_status(username, activate_status: str):
# use email to find user. check exist and unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# check activate status different from new
usr = user_list[0]
# format activate_status before handle
_activate_status = activate_status.lower()
target_status = {
'on': ActiveEnum.ACTIVE.value,
'off': ActiveEnum.INACTIVE.value,
}.get(_activate_status)
if not target_status:
raise AdminException(f"Invalid activate_status: {activate_status}")
if target_status == usr.is_active:
return f"User activate status is already {_activate_status}!"
# update is_active
UserService.update_user(usr.id, {"is_active": target_status})
return f"Turn {_activate_status} user activate status successfully!"
class UserServiceMgr:
@staticmethod
def get_user_datasets(username):
# use email to find user.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# find tenants
usr = user_list[0]
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted kb and owned kb
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
@staticmethod
def get_user_agents(username):
# use email to find user.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# find tenants
usr = user_list[0]
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted agents and owned agents
return UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
class ServiceMgr:

View File

@ -80,7 +80,7 @@ Here's description of each category:
- Prioritize the most specific applicable category
- Return only the category name without explanations
- Use "Other" only when no other category fits
""".format(
"\n - ".join(list(self.category_description.keys())),
"\n".join(descriptions)
@ -96,7 +96,7 @@ Here's description of each category:
class Categorize(LLM, ABC):
component_name = "Categorize"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
user_prompt = """
---- Real Data ----
{}
{}
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
self.set_output("_next", cpn_ids)
def thoughts(self) -> str:
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))

View File

@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
"type": "line"
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self._param.method == "split":
self._split(kwargs.get("line"))

View File

@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
class ArXiv(ToolBase, ABC):
component_name = "ArXiv"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
""".format(self.get_input().get("query", "-_-!"))

View File

@ -129,7 +129,7 @@ module.exports = { main };
class CodeExec(ToolBase, ABC):
component_name = "CodeExec"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
lang = kwargs.get("lang", self._param.lang)
script = kwargs.get("script", self._param.script)

View File

@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
class DuckDuckGo(ToolBase, ABC):
component_name = "DuckDuckGo"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
""".format(self.get_input().get("query", "-_-!"))

View File

@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
class Email(ToolBase, ABC):
component_name = "Email"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if not kwargs.get("to_email"):
self.set_output("success", False)
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
To: {}
Subject: {}
Your email is on its way—sit tight!
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))

View File

@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
class ExeSQL(ToolBase, ABC):
component_name = "ExeSQL"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
def convert_decimals(obj):

View File

@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
class GitHub(ToolBase, ABC):
component_name = "GitHub"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))

View File

@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
class Google(ToolBase, ABC):
component_name = "Google"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("q"):
self.set_output("formalized_content", "")
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
""".format(self.get_input().get("query", "-_-!"))

View File

@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
class GoogleScholar(ToolBase, ABC):
component_name = "GoogleScholar"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))

View File

@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
self.meta:ToolMeta = {
"name": "pubmed_search",
"description": """
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
In addition to MEDLINE, PubMed provides access to:
- older references from the print version of Index Medicus, back to 1951 and earlier
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
class PubMed(ToolBase, ABC):
component_name = "PubMed"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))

View File

@ -74,7 +74,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
@ -164,18 +164,18 @@ class Retrieval(ToolBase, ABC):
# Format the chunks for JSON output (similar to how other tools do it)
json_output = kbinfos["chunks"].copy()
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
# Set both formalized content and JSON output
self.set_output("formalized_content", form_cnt)
self.set_output("json", json_output)
return form_cnt
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
""".format(self.get_input().get("query", "-_-!"))

View File

@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
class SearXNG(ToolBase, ABC):
component_name = "SearXNG"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
# Gracefully handle try-run without inputs
query = kwargs.get("query")
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
last_e = ""
for _ in range(self._param.max_retries+1):
try:
# 构建搜索参数
search_params = {
'q': query,
'format': 'json',
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
'pageno': 1
}
# 发送搜索请求
response = requests.get(
f"{searxng_url}/search",
params=search_params,
timeout=10
)
response.raise_for_status()
data = response.json()
# 验证响应数据
if not data or not isinstance(data, dict):
raise ValueError("Invalid response from SearXNG")
results = data.get("results", [])
if not isinstance(results, list):
raise ValueError("Invalid results format from SearXNG")
# 限制结果数量
results = results[:self._param.top_n]
# 处理搜索结果
self._retrieve_chunks(results,
get_title=lambda r: r.get("title", ""),
get_url=lambda r: r.get("url", ""),
get_content=lambda r: r.get("content", ""))
self.set_output("json", results)
return self.output("formalized_content")
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Searching with SearXNG for relevant results...
""".format(self.get_input().get("query", "-_-!"))

View File

@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
self.meta:ToolMeta = {
"name": "tavily_search",
"description": """
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
When searching:
- Start with specific query which should focus on just a single aspect.
- Number of keywords in query should be less than 5.
@ -101,7 +101,7 @@ When searching:
class TavilySearch(ToolBase, ABC):
component_name = "TavilySearch"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
class TavilyExtract(ToolBase, ABC):
component_name = "TavilyExtract"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))

View File

@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
class WenCai(ToolBase, ABC):
component_name = "WenCai"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("report", "")
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))

View File

@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
class Wikipedia(ToolBase, ABC):
component_name = "Wikipedia"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
def thoughts(self) -> str:
return """
Keywords: {}
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
""".format(self.get_input().get("query", "-_-!"))

View File

@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
class YahooFinance(ToolBase, ABC):
component_name = "YahooFinance"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if not kwargs.get("stock_code"):
self.set_output("report", "")
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
assert False, self.output()
def thoughts(self) -> str:
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))

View File

@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api.db import StatusEnum
from api.db.db_models import close_connection
from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands
from api.utils.json import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail
from flask_session import Session

View File

@ -23,7 +23,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters, active_required
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -38,6 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@active_required
@validate_request("name")
def create():
req = request.json

View File

@ -23,6 +23,11 @@ class StatusEnum(Enum):
INVALID = "0"
class ActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class UserTenantRole(StrEnum):
OWNER = 'owner'
ADMIN = 'admin'

View File

@ -26,12 +26,14 @@ from functools import wraps
from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api import settings, utils
from api.db import ParserType, SerializedType
from api.utils.json import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
def singleton(cls, *args, **kw):
@ -70,12 +72,12 @@ class JSONField(LongTextField):
def db_value(self, value):
if value is None:
value = self.default_value
return utils.json_dumps(value)
return json_dumps(value)
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True)
return serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return None
return utils.json_dumps(value, with_type=True)
return json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value)
return deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
super().__init__(*args, **kwargs)
def execute_sql(self, sql, params=None, commit=True):
from peewee import OperationalError
for attempt in range(self.max_retries + 1):
try:
return super().execute_sql(sql, params, commit)
except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
except (OperationalError, InterfaceError) as e:
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt))
time.sleep(self.retry_delay * (2 ** attempt))
else:
logging.error(f"DB execution failure: {e}")
raise
return None
def _handle_connection_loss(self):
self.close_all()
self.connect()
# self.close_all()
# self.connect()
try:
self.close()
except Exception:
pass
try:
self.connect()
except Exception as e:
logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1)
self.connect()
def begin(self):
from peewee import OperationalError
for attempt in range(self.max_retries + 1):
try:
return super().begin()
except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
except (OperationalError, InterfaceError) as e:
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt))
time.sleep(self.retry_delay * (2 ** attempt))
else:
raise
@ -299,7 +328,16 @@ class BaseDataBase:
def __init__(self):
database_config = settings.DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
pool_config = {
'max_retries': 5,
'retry_delay': 1,
}
database_config.update(pool_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
db_name, **database_config
)
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info("init database on cluster mode successfully")

View File

View File

@ -0,0 +1,120 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import uuid
from api import settings
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.file_service import FileService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserService, UserTenantService
def create_new_user(user_info: dict) -> dict:
"""
Add a new user, and create tenant, tenant llm, file folder for new user.
:param user_info: {
"email": <example@example.com>,
"nickname": <str, "name">,
"password": <decrypted password>,
"login_channel": <enum, "password">,
"is_superuser": <bool, role == "admin">,
}
:return: {
"success": <bool>,
"user_info": <dict>, # if true, return user_info
}
"""
# generate user_id and access_token for user
user_id = uuid.uuid1().hex
user_info['id'] = user_id
user_info['access_token'] = uuid.uuid1().hex
# construct tenant info
tenant = {
"id": user_id,
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
}
usr_tenant = {
"tenant_id": user_id,
"user_id": user_id,
"invited_by": user_id,
"role": UserTenantRole.OWNER,
}
# construct file folder info
file_id = uuid.uuid1().hex
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
try:
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user_info):
return {"success": False}
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
FileService.insert(file)
return {
"success": True,
"user_info": user_info,
}
except Exception as create_error:
logging.exception(create_error)
# rollback
try:
TenantService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
logging.exception(e)
try:
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
except Exception as e:
logging.exception(e)
try:
FileService.delete_by_id(file["id"])
except Exception as e:
logging.exception(e)
# delete user row finally
try:
UserService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
# reraise
raise create_error

View File

@ -61,6 +61,36 @@ class UserCanvasService(CommonService):
return list(agents.dicts())
@classmethod
@DB.connection_context()
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted agents, be cautious
fields = [
cls.model.title,
cls.model.permission,
cls.model.canvas_type,
cls.model.canvas_category
]
# find team agents and owned agents
agents = cls.model.select(*fields).where(
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id
)
)
# sort by create_time, asc
agents.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 50
res = []
while True:
ag_batch = agents.offset(offset).limit(limit)
_temp = list(ag_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, pid):

View File

@ -14,12 +14,24 @@
# limitations under the License.
#
from datetime import datetime
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import peewee
from peewee import InterfaceError, OperationalError
from api.db.db_models import DB
from api.utils import current_timestamp, datetime_format, get_uuid
def retry_db_operation(func):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=5),
retry=retry_if_exception_type((InterfaceError, OperationalError)),
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
reraise=True,
)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
class CommonService:
"""Base service class that provides common database operations.
@ -202,6 +214,7 @@ class CommonService:
@classmethod
@DB.connection_context()
@retry_db_operation
def update_by_id(cls, pid, data):
# Update a single record by ID
# Args:

View File

@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
return list(kbs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted kb, be cautious.
fields = [
cls.model.name,
cls.model.language,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.status,
cls.model.create_date,
cls.model.update_date
]
# find team kb and owned kb
kbs = cls.model.select(*fields).where(
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id
)
)
# sort by create_time asc
kbs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later.
offset, limit = 0, 50
res = []
while True:
kb_batch = kbs.offset(offset).limit(limit)
_temp = list(kb_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):

View File

@ -100,6 +100,12 @@ class UserService(CommonService):
else:
return None
@classmethod
@DB.connection_context()
def query_user_by_email(cls, email):
users = cls.model.select().where((cls.model.email == email))
return list(users)
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
@ -133,6 +139,17 @@ class UserService(CommonService):
cls.model.update(user_dict).where(
cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def update_user_password(cls, user_id, new_password):
with DB.atomic():
update_dict = {
"password": generate_password_hash(str(new_password)),
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now())
}
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def is_admin(cls, user_id):

View File

@ -41,7 +41,7 @@ from api import utils
from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from api.versions import get_ragflow_version
from api.utils import show_configs
from api.utils.configs import show_configs
from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock

View File

@ -24,7 +24,7 @@ import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config
from api.utils.configs import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory
from rag.nlp import search

View File

@ -16,182 +16,15 @@
import base64
import datetime
import hashlib
import io
import json
import os
import pickle
import socket
import time
import uuid
import requests
import logging
import copy
from enum import Enum, IntEnum
import importlib
from filelock import FileLock
from api.constants import SERVICE_CONF
from . import file_utils
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def rag_uuid():
return uuid.uuid1().hex
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)
from .common import string_to_bytes
def current_timestamp():
@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
return time_stamp
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def get_lan_ip():
if os.name != "nt":
import fcntl
@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict):
return in_dict
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid():
return uuid.uuid1().hex
@ -375,5 +128,5 @@ def delta_seconds(date_string: str):
return (datetime.datetime.now() - dt).total_seconds()
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod

View File

@ -39,6 +39,7 @@ from flask import (
make_response,
send_file,
)
from flask_login import current_user
from flask import (
request as flask_request,
)
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db import ActiveEnum
from api.db.db_models import APIToken
from api.db.services import UserService
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from api.utils.json import CustomJSONEncoder, json_dumps
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
return decorator
def active_required(f):
@wraps(f)
def wrapper(*args, **kwargs):
user_id = current_user.id
usr = UserService.filter_by_id(user_id)
# check is_active
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
return f(*args, **kwargs)
return wrapper
def is_localhost(ip):
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}

23
api/utils/common.py Normal file
View File

@ -0,0 +1,23 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")

179
api/utils/configs.py Normal file
View File

@ -0,0 +1,179 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import io
import copy
import logging
import base64
import pickle
import importlib
from api.utils import file_utils
from filelock import FileLock
from api.utils.common import bytes_to_string, string_to_bytes
from api.constants import SERVICE_CONF
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)

View File

@ -23,6 +23,9 @@ from api.utils import file_utils
def crypt(line):
"""
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
"""
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)

78
api/utils/json.py Normal file
View File

@ -0,0 +1,78 @@
import datetime
import json
from enum import Enum, IntEnum
from api.utils.common import string_to_bytes, bytes_to_string
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)

View File

@ -350,7 +350,7 @@ class TextRecognizer:
def close(self):
# close session and release manually
logging.info('Close TextRecognizer.')
logging.info('Close text recognizer.')
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
@ -490,7 +490,7 @@ class TextDetector:
return dt_boxes
def close(self):
logging.info("Close TextDetector.")
logging.info("Close text detector.")
if hasattr(self, "predictor"):
del self.predictor
gc.collect()

View File

@ -143,7 +143,7 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
@ -156,12 +156,12 @@ class Base(ABC):
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
@ -643,7 +643,7 @@ class ZhipuChat(Base):
del gen_conf["max_tokens"]
gen_conf = self._clean_conf_plealty(gen_conf)
return gen_conf
def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]

View File

@ -56,7 +56,7 @@ class FulltextQueryer:
def rmWWW(txt):
patts = [
(
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),

View File

@ -15,7 +15,7 @@
#
import os
import logging
from api.utils import get_base_config, decrypt_database_config
from api.utils.configs import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory
# Server

View File

@ -3,7 +3,7 @@ import logging
import pymysql
from urllib.parse import quote_plus
from api.utils import get_base_config
from api.utils.configs import get_base_config
from rag.utils import singleton

View File

@ -949,7 +949,8 @@ export default {
multimodalModels: 'Мультимодальные модели',
textOnlyModels: 'Только текстовые модели',
allModels: 'Все модели',
codeExecDescription: 'Напишите свою пользовательскую логику на Python или Javascript.',
codeExecDescription:
'Напишите свою пользовательскую логику на Python или Javascript.',
stringTransformDescription:
'Изменяет текстовое содержимое. В настоящее время поддерживает: разделение или объединение текста.',
foundation: 'Основа',