<feat> 安全限制

This commit is contained in:
tangyi
2025-03-27 15:44:49 +08:00
parent 370b2ac9da
commit b0463903f5
12 changed files with 551 additions and 12 deletions

3
src/db/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
MySQL数据库操作包
"""

View File

@ -5,8 +5,17 @@ from mysql.connector import Error
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from ..security.sql_analyzer import SQLOperationType
from ..security.query_limiter import QueryLimiter
from ..security.interceptor import SQLInterceptor, SecurityException
logger = logging.getLogger("mysql_server")
# 初始化安全组件
sql_analyzer = SQLOperationType()
query_limiter = QueryLimiter()
sql_interceptor = SQLInterceptor(sql_analyzer)
def get_db_config():
"""动态获取数据库配置"""
return {
@ -15,7 +24,8 @@ def get_db_config():
'password': os.getenv('MYSQL_PASSWORD', ''),
'database': os.getenv('MYSQL_DATABASE', ''),
'port': int(os.getenv('MYSQL_PORT', '3306')),
'connection_timeout': 5
'connection_timeout': 5,
'auth_plugin': 'mysql_native_password' # 指定认证插件
}
@contextmanager
@ -54,7 +64,7 @@ def get_db_connection():
connection.close()
logger.debug("数据库连接已关闭")
def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""
在给定的数据库连接上执行查询
@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
params: 查询参数 (可选)
Returns:
查询结果列表
查询结果列表,如果是修改操作则返回影响的行数
Raises:
SecurityException: 当操作被安全机制拒绝时
ValueError: 当查询执行失败时
"""
cursor = None
try:
# 安全检查
if not await sql_interceptor.check_operation(query):
raise SecurityException("操作被安全机制拒绝")
cursor = connection.cursor(dictionary=True)
# 执行查询
@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
else:
cursor.execute(query)
# 获取结果
# 获取操作类型
operation = query.strip().split()[0].upper()
# 对于修改操作,提交事务并返回影响的行数
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
affected_rows = cursor.rowcount
# 提交事务,确保更改被保存
connection.commit()
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
return [{'affected_rows': affected_rows}]
# 对于查询操作,返回结果集
results = cursor.fetchall()
logger.debug(f"查询返回 {len(results)} 条结果")
return results
except SecurityException as security_err:
logger.error(f"安全检查失败: {str(security_err)}")
raise
except mysql.connector.Error as query_err:
# 如果发生错误,进行回滚
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
try:
connection.rollback()
logger.debug("事务已回滚")
except:
pass
logger.error(f"查询执行失败: {str(query_err)}")
raise ValueError(f"查询执行失败: {str(query_err)}")
finally: