<feat> 安全限制

This commit is contained in:
tangyi
2025-03-27 15:44:49 +08:00
parent 370b2ac9da
commit b0463903f5
12 changed files with 551 additions and 12 deletions

View File

@ -7,4 +7,32 @@ MYSQL_DATABASE=test
# 服务器配置
PORT=3000
HOST=127.0.0.1
HOST=127.0.0.1
# 环境配置
ENV_TYPE=development # 环境类型development/production
# SQL风险控制配置
ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许的风险等级LOW/MEDIUM/HIGH/CRITICAL
# 如需执行无WHERE条件的DELETE操作需要将CRITICAL添加到ALLOWED_RISK_LEVELS中
# 例如ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH,CRITICAL
# 禁止的SQL模式正则表达式用逗号分隔
BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE
# SQL安全检查配置
ENABLE_QUERY_CHECK=true # 是否启用SQL安全检查
# 配置示例说明
# -------------------
# 开发环境配置示例:
# ENV_TYPE=development
# ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许除CRITICAL外的所有风险等级
# BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE # 仅禁止最危险的操作
# ENABLE_QUERY_CHECK=true # 启用SQL安全检查
# 生产环境配置示例:
# ENV_TYPE=production
# ALLOWED_RISK_LEVELS=LOW # 只允许低风险操作SELECT
# BLOCKED_PATTERNS=DROP,TRUNCATE,DELETE,UPDATE # 禁止所有数据修改操作
# ENABLE_QUERY_CHECK=true # 强制启用SQL安全检查

8
.gitignore vendored
View File

@ -8,4 +8,10 @@ vendor/
.idea
.cursorrules
.cursor
**/__pycache__/
**/__pycache__/
tests/
.env
.coverage
docs/
.pytest_cache/
requirements-dev.txt

View File

@ -8,7 +8,14 @@
- 支持SSEServer-Sent Events实时数据传输
- 提供MySQL数据库查询接口
- 完整的日志记录系统
- 自动事务管理(提交/回滚)
- 环境变量配置支持
- SQL安全检查机制
- 风险等级控制
- SQL注入防护
- 危险操作拦截
- WHERE子句强制检查
- 自动返回修改操作影响的行数
## 系统要求
@ -36,6 +43,7 @@ pip install -r requirements.txt
`.env`文件中配置以下参数:
### 基本配置
- `HOST`: 服务器监听地址默认127.0.0.1
- `PORT`: 服务器监听端口默认3000
- `MYSQL_HOST`: MySQL服务器地址
@ -44,6 +52,45 @@ pip install -r requirements.txt
- `MYSQL_PASSWORD`: MySQL密码
- `MYSQL_DATABASE`: MySQL数据库名
### SQL安全配置
- `ENV_TYPE`: 环境类型development/production
- `ALLOWED_RISK_LEVELS`: 允许的风险等级LOW/MEDIUM/HIGH/CRITICAL
- `BLOCKED_PATTERNS`: 禁止的SQL模式正则表达式用逗号分隔
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查true/false
## SQL安全机制
### 风险等级控制
- LOW: 查询操作SELECT
- MEDIUM: 基本数据修改INSERT有WHERE的UPDATE/DELETE
- HIGH: 结构变更CREATE/ALTER和无WHERE的UPDATE
- CRITICAL: 危险操作DROP/TRUNCATE和无WHERE的DELETE操作
### 环境变量加载顺序
项目使用python-dotenv加载环境变量需要确保在导入其他模块前先加载环境变量否则可能会导致配置未被正确应用。
### 安全检查机制
- 强制要求UPDATE/DELETE操作包含WHERE子句
- SQL语句语法检查
- 危险操作模式检测
- 自动识别受影响的表
- 风险等级评估
### 环境特性
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
- 生产环境默认只允许LOW风险操作严格限制数据修改
### 安全拦截
- SQL语句长度限制
- 危险操作模式检测
- 自动识别受影响的表
- SQL注入防护
### 事务管理
- 对于修改操作INSERT/UPDATE/DELETE会自动提交事务
- 执行错误时自动回滚事务
- 返回操作影响的行数
## 启动服务器
运行以下命令启动服务器:
@ -61,11 +108,34 @@ python src/server.py
├── src/ # 源代码目录
│ ├── server.py # 主服务器文件
│ ├── db/ # 数据库相关代码
│ ├── security/ # SQL安全相关代码
│ │ ├── interceptor.py # SQL拦截器
│ │ ├── query_limiter.py # SQL安全检查器
│ │ └── sql_analyzer.py # SQL分析器
│ └── tools/ # 工具类代码
├── tests/ # 测试代码目录
├── .env.example # 环境变量示例文件
└── requirements.txt # 项目依赖文件
```
## 常见问题解决
### DELETE操作未执行成功
- 检查DELETE操作是否包含WHERE条件
- 无WHERE条件的DELETE操作被标记为CRITICAL风险级别
- 确保环境变量ALLOWED_RISK_LEVELS中包含CRITICAL如果需要执行该操作
- 检查影响行数返回值,确认操作是否实际影响了数据库
### 环境变量未生效
- 确保在server.py中的load_dotenv()调用发生在导入其他模块之前
- 重启应用以确保环境变量被正确加载
- 检查日志中"从环境变量读取到的风险等级设置"的输出
### 操作被安全机制拒绝
- 检查操作的风险级别是否在允许的范围内
- 如果需要执行高风险操作相应地调整ALLOWED_RISK_LEVELS
- 对于不带WHERE条件的UPDATE或DELETE可以添加条件即使是WHERE 1=1降低风险级别
## 日志系统
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
@ -75,7 +145,9 @@ python src/server.py
服务器包含完善的错误处理机制:
- MySQL连接器导入检查
- 数据库配置验证
- SQL安全检查
- 运行时错误捕获和记录
- 事务自动回滚
## 贡献指南

3
src/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
MySQL查询服务器包
"""

3
src/db/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
MySQL数据库操作包
"""

View File

@ -5,8 +5,17 @@ from mysql.connector import Error
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from ..security.sql_analyzer import SQLOperationType
from ..security.query_limiter import QueryLimiter
from ..security.interceptor import SQLInterceptor, SecurityException
logger = logging.getLogger("mysql_server")
# 初始化安全组件
sql_analyzer = SQLOperationType()
query_limiter = QueryLimiter()
sql_interceptor = SQLInterceptor(sql_analyzer)
def get_db_config():
"""动态获取数据库配置"""
return {
@ -15,7 +24,8 @@ def get_db_config():
'password': os.getenv('MYSQL_PASSWORD', ''),
'database': os.getenv('MYSQL_DATABASE', ''),
'port': int(os.getenv('MYSQL_PORT', '3306')),
'connection_timeout': 5
'connection_timeout': 5,
'auth_plugin': 'mysql_native_password' # 指定认证插件
}
@contextmanager
@ -54,7 +64,7 @@ def get_db_connection():
connection.close()
logger.debug("数据库连接已关闭")
def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""
在给定的数据库连接上执行查询
@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
params: 查询参数 (可选)
Returns:
查询结果列表
查询结果列表,如果是修改操作则返回影响的行数
Raises:
SecurityException: 当操作被安全机制拒绝时
ValueError: 当查询执行失败时
"""
cursor = None
try:
# 安全检查
if not await sql_interceptor.check_operation(query):
raise SecurityException("操作被安全机制拒绝")
cursor = connection.cursor(dictionary=True)
# 执行查询
@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
else:
cursor.execute(query)
# 获取结果
# 获取操作类型
operation = query.strip().split()[0].upper()
# 对于修改操作,提交事务并返回影响的行数
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
affected_rows = cursor.rowcount
# 提交事务,确保更改被保存
connection.commit()
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
return [{'affected_rows': affected_rows}]
# 对于查询操作,返回结果集
results = cursor.fetchall()
logger.debug(f"查询返回 {len(results)} 条结果")
return results
except SecurityException as security_err:
logger.error(f"安全检查失败: {str(security_err)}")
raise
except mysql.connector.Error as query_err:
# 如果发生错误,进行回滚
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
try:
connection.rollback()
logger.debug("事务已回滚")
except:
pass
logger.error(f"查询执行失败: {str(query_err)}")
raise ValueError(f"查询执行失败: {str(query_err)}")
finally:

View File

@ -0,0 +1,84 @@
import logging
import os
from typing import List, Dict
from .sql_analyzer import SQLOperationType, SQLRiskLevel
logger = logging.getLogger(__name__)
class SecurityException(Exception):
"""安全相关异常"""
pass
class SQLInterceptor:
"""SQL操作拦截器"""
def __init__(self, analyzer: SQLOperationType):
self.analyzer = analyzer
# 设置最大SQL长度限制默认1000个字符
self.max_sql_length = 1000
async def check_operation(self, sql_query: str) -> bool:
"""
检查SQL操作是否允许执行
Args:
sql_query: SQL查询语句
Returns:
bool: 是否允许执行
Raises:
SecurityException: 当操作被拒绝时抛出
"""
try:
# 检查SQL是否为空
if not sql_query or not sql_query.strip():
raise SecurityException("SQL语句不能为空")
# 检查SQL长度
if len(sql_query) > self.max_sql_length:
raise SecurityException(f"SQL语句长度({len(sql_query)})超出限制({self.max_sql_length})")
# 检查SQL是否有效
sql_parts = sql_query.strip().split()
if not sql_parts:
raise SecurityException("SQL语句格式无效")
operation = sql_parts[0].upper()
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
raise SecurityException(f"不支持的SQL操作: {operation}")
# 分析SQL风险
risk_analysis = self.analyzer.analyze_risk(sql_query)
# 检查是否是危险操作
if risk_analysis['is_dangerous']:
raise SecurityException(
f"检测到危险操作: {risk_analysis['operation']}"
)
# 检查操作是否被允许
if not risk_analysis['is_allowed']:
raise SecurityException(
f"当前操作风险等级({risk_analysis['risk_level'].name})不被允许执行,"
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
)
# 记录详细日志
logger.info(
f"SQL操作检查通过 - "
f"操作: {risk_analysis['operation']}, "
f"风险等级: {risk_analysis['risk_level'].name}, "
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
)
return True
except SecurityException as e:
logger.error(str(e))
raise
except Exception as e:
error_msg = f"安全检查失败: {str(e)}"
logger.error(error_msg)
raise SecurityException(error_msg)

View File

@ -0,0 +1,68 @@
import os
import logging
from typing import Tuple
logger = logging.getLogger(__name__)
class QueryLimiter:
"""查询安全检查器"""
def __init__(self):
# 解析启用状态(默认启用)
enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true')
self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'}
def check_query(self, sql_query: str) -> Tuple[bool, str]:
"""
检查SQL查询是否安全
Args:
sql_query: SQL查询语句
Returns:
Tuple[bool, str]: (是否允许执行, 错误信息)
"""
if not self.enable_check:
return True, ""
sql_query = sql_query.strip().upper()
operation_type = self._get_operation_type(sql_query)
# 检查是否为无 WHERE 子句的更新/删除操作
if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query:
error_msg = f"{operation_type}操作必须包含WHERE子句"
logger.warning(f"查询被限制: {error_msg}")
return False, error_msg
return True, ""
def _get_operation_type(self, sql_query: str) -> str:
"""获取SQL操作类型"""
if not sql_query:
return ""
words = sql_query.split()
if not words:
return ""
return words[0].upper()
def _parse_int_env(self, env_name: str, default: int) -> int:
"""解析整数类型的环境变量"""
try:
return int(os.getenv(env_name, str(default)))
except (ValueError, TypeError):
return default
def update_limits(self, new_limits: dict):
"""
更新限制阈值
Args:
new_limits: 新的限制值字典
"""
for operation, limit in new_limits.items():
if operation in self.max_limits:
try:
self.max_limits[operation] = int(limit)
logger.info(f"更新{operation}操作的限制为: {limit}")
except (ValueError, TypeError):
logger.warning(f"无效的限制值: {operation}={limit}")

View File

@ -0,0 +1,221 @@
import re
import os
from enum import IntEnum, Enum
import logging
from typing import Set, List
logger = logging.getLogger(__name__)
class SQLRiskLevel(IntEnum):
"""SQL操作风险等级"""
LOW = 1 # 查询操作SELECT
MEDIUM = 2 # 基本数据修改INSERT有WHERE的UPDATE/DELETE
HIGH = 3 # 结构变更CREATE/ALTER和无WHERE的数据修改
CRITICAL = 4 # 危险操作DROP/TRUNCATE等
class EnvironmentType(Enum):
"""环境类型"""
DEVELOPMENT = 'development'
PRODUCTION = 'production'
class SQLOperationType:
"""SQL操作类型分析器"""
def __init__(self):
# 环境类型处理
env_type_str = os.getenv('ENV_TYPE', 'development').lower()
try:
self.env_type = EnvironmentType(env_type_str)
except ValueError:
logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development")
self.env_type = EnvironmentType.DEVELOPMENT
# 基础操作集合
self.ddl_operations = {
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
}
self.dml_operations = {
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
}
# 风险等级配置
self.allowed_risk_levels = self._parse_risk_levels()
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
# 生产环境特殊处理如果没有明确配置风险等级则只允许LOW风险操作
if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
self.allowed_risk_levels = {SQLRiskLevel.LOW}
logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}")
logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}")
def _parse_risk_levels(self) -> Set[SQLRiskLevel]:
"""解析允许的风险等级"""
allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
allowed_levels = set()
logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'")
for level_str in allowed_levels_str.upper().split(','):
level_str = level_str.strip()
try:
allowed_levels.add(SQLRiskLevel[level_str])
except KeyError:
logger.warning(f"未知的风险等级配置: {level_str}")
return allowed_levels
def _parse_blocked_patterns(self, env_var: str) -> List[str]:
"""解析禁止的操作模式"""
patterns = os.getenv(env_var, '').split(',')
return [p.strip() for p in patterns if p.strip()]
def analyze_risk(self, sql_query: str) -> dict:
"""
分析SQL查询的风险级别和影响范围
Args:
sql_query: SQL查询语句
Returns:
dict: 包含风险分析结果的字典
"""
sql_query = sql_query.strip()
# 处理空SQL
if not sql_query:
return {
'operation': '',
'operation_type': 'UNKNOWN',
'is_dangerous': True,
'affected_tables': [],
'estimated_impact': {
'operation': '',
'estimated_rows': 0,
'needs_where': False,
'has_where': False
},
'risk_level': SQLRiskLevel.HIGH,
'is_allowed': False
}
operation = sql_query.split()[0].upper()
# 基本风险分析
risk_analysis = {
'operation': operation,
'operation_type': 'DDL' if operation in self.ddl_operations else 'DML',
'is_dangerous': self._check_dangerous_patterns(sql_query),
'affected_tables': self._get_affected_tables(sql_query),
'estimated_impact': self._estimate_impact(sql_query)
}
# 计算风险等级
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'])
risk_analysis['risk_level'] = risk_level
risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels
return risk_analysis
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel:
"""
计算操作风险等级
规则:
1. 危险操作(匹配危险模式)=> CRITICAL
2. DDL操作
- CREATE/ALTER => HIGH
- DROP/TRUNCATE => CRITICAL
3. DML操作
- SELECT => LOW
- INSERT => MEDIUM
- UPDATE/DELETE有WHERE=> MEDIUM
- UPDATE无WHERE=> HIGH
- DELETE无WHERE=> CRITICAL
"""
# 危险操作
if is_dangerous:
return SQLRiskLevel.CRITICAL
# 生产环境中非SELECT操作
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
return SQLRiskLevel.CRITICAL
# DDL操作
if operation in self.ddl_operations:
if operation in {'DROP', 'TRUNCATE'}:
return SQLRiskLevel.CRITICAL
return SQLRiskLevel.HIGH
# DML操作
if operation == 'SELECT':
return SQLRiskLevel.LOW
elif operation == 'INSERT':
return SQLRiskLevel.MEDIUM
elif operation == 'UPDATE':
return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
elif operation == 'DELETE':
# 无WHERE条件的DELETE操作视为CRITICAL风险
return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
return SQLRiskLevel.HIGH
def _check_dangerous_patterns(self, sql_query: str) -> bool:
"""检查是否匹配危险操作模式"""
sql_upper = sql_query.upper()
# 生产环境额外的安全检查
if self.env_type == EnvironmentType.PRODUCTION:
# 生产环境中禁止所有非SELECT操作
if sql_upper.split()[0] != 'SELECT':
return True
for pattern in self.blocked_patterns:
if re.search(pattern, sql_upper, re.IGNORECASE):
return True
return False
def _get_affected_tables(self, sql_query: str) -> list:
"""获取受影响的表名列表"""
words = sql_query.upper().split()
tables = []
for i, word in enumerate(words):
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
if i + 1 < len(words):
table = words[i + 1].strip('`;')
if table not in {'SELECT', 'WHERE', 'SET'}:
tables.append(table)
return list(set(tables))
def _estimate_impact(self, sql_query: str) -> dict:
"""
估算查询影响范围
Returns:
dict: 包含预估影响的字典
"""
operation = sql_query.split()[0].upper()
impact = {
'operation': operation,
'estimated_rows': 0,
'needs_where': operation in {'UPDATE', 'DELETE'},
'has_where': 'WHERE' in sql_query.upper()
}
# 根据环境类型调整估算
if self.env_type == EnvironmentType.PRODUCTION:
if operation == 'SELECT':
impact['estimated_rows'] = 100
else:
impact['estimated_rows'] = float('inf') # 生产环境中非SELECT操作视为影响无限行
else:
if operation == 'SELECT':
impact['estimated_rows'] = 100
elif operation in {'UPDATE', 'DELETE'}:
impact['estimated_rows'] = 1000 if impact['has_where'] else float('inf')
return impact

View File

@ -2,6 +2,11 @@ from mcp.server.fastmcp import FastMCP
import os
import logging
from dotenv import load_dotenv
# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载
load_dotenv()
# 导入自定义模块 - 确保在load_dotenv之后导入
from src.tools.mysql_tool import register_mysql_tool
# 配置日志
@ -11,6 +16,10 @@ logging.basicConfig(
)
logger = logging.getLogger("mysql_server")
# 记录环境变量加载情况
logger.debug("已加载环境变量")
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
# 尝试导入MySQL连接器
try:
import mysql.connector
@ -21,10 +30,6 @@ except ImportError as e:
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
mysql_available = False
# 加载环境变量
load_dotenv()
logger.debug("已加载环境变量")
# 从环境变量获取服务器配置
host = os.getenv('HOST', '127.0.0.1')
port = int(os.getenv('PORT', '3000'))

3
src/tools/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
MySQL工具包
"""

View File

@ -24,7 +24,7 @@ def register_mysql_tool(mcp: FastMCP):
logger.debug("注册MySQL查询工具...")
@mcp.tool()
def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
"""
执行MySQL查询并返回结果
@ -39,7 +39,14 @@ def register_mysql_tool(mcp: FastMCP):
try:
with get_db_connection() as connection:
results = execute_query(connection, query, params)
results = await execute_query(connection, query, params)
# 检查是否是修改操作返回的影响行数
operation = query.strip().split()[0].upper()
if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
affected_rows = results[0]['affected_rows']
logger.info(f"{operation}操作影响了{affected_rows}行数据")
return json.dumps(results, default=str)
except Exception as e: