mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 09:42:27 +08:00
<feat> 安全限制
This commit is contained in:
28
.env.example
28
.env.example
@ -8,3 +8,31 @@ MYSQL_DATABASE=test
|
||||
# 服务器配置
|
||||
PORT=3000
|
||||
HOST=127.0.0.1
|
||||
|
||||
# 环境配置
|
||||
ENV_TYPE=development # 环境类型:development/production
|
||||
|
||||
# SQL风险控制配置
|
||||
ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
# 如需执行无WHERE条件的DELETE操作,需要将CRITICAL添加到ALLOWED_RISK_LEVELS中
|
||||
# 例如:ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH,CRITICAL
|
||||
|
||||
# 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||
BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE
|
||||
|
||||
# SQL安全检查配置
|
||||
ENABLE_QUERY_CHECK=true # 是否启用SQL安全检查
|
||||
|
||||
# 配置示例说明
|
||||
# -------------------
|
||||
# 开发环境配置示例:
|
||||
# ENV_TYPE=development
|
||||
# ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许除CRITICAL外的所有风险等级
|
||||
# BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE # 仅禁止最危险的操作
|
||||
# ENABLE_QUERY_CHECK=true # 启用SQL安全检查
|
||||
|
||||
# 生产环境配置示例:
|
||||
# ENV_TYPE=production
|
||||
# ALLOWED_RISK_LEVELS=LOW # 只允许低风险操作(SELECT)
|
||||
# BLOCKED_PATTERNS=DROP,TRUNCATE,DELETE,UPDATE # 禁止所有数据修改操作
|
||||
# ENABLE_QUERY_CHECK=true # 强制启用SQL安全检查
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -9,3 +9,9 @@ vendor/
|
||||
.cursorrules
|
||||
.cursor
|
||||
**/__pycache__/
|
||||
tests/
|
||||
.env
|
||||
.coverage
|
||||
docs/
|
||||
.pytest_cache/
|
||||
requirements-dev.txt
|
||||
72
README.md
72
README.md
@ -8,7 +8,14 @@
|
||||
- 支持SSE(Server-Sent Events)实时数据传输
|
||||
- 提供MySQL数据库查询接口
|
||||
- 完整的日志记录系统
|
||||
- 自动事务管理(提交/回滚)
|
||||
- 环境变量配置支持
|
||||
- SQL安全检查机制
|
||||
- 风险等级控制
|
||||
- SQL注入防护
|
||||
- 危险操作拦截
|
||||
- WHERE子句强制检查
|
||||
- 自动返回修改操作影响的行数
|
||||
|
||||
## 系统要求
|
||||
|
||||
@ -36,6 +43,7 @@ pip install -r requirements.txt
|
||||
|
||||
在`.env`文件中配置以下参数:
|
||||
|
||||
### 基本配置
|
||||
- `HOST`: 服务器监听地址(默认:127.0.0.1)
|
||||
- `PORT`: 服务器监听端口(默认:3000)
|
||||
- `MYSQL_HOST`: MySQL服务器地址
|
||||
@ -44,6 +52,45 @@ pip install -r requirements.txt
|
||||
- `MYSQL_PASSWORD`: MySQL密码
|
||||
- `MYSQL_DATABASE`: MySQL数据库名
|
||||
|
||||
### SQL安全配置
|
||||
- `ENV_TYPE`: 环境类型(development/production)
|
||||
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
||||
|
||||
## SQL安全机制
|
||||
|
||||
### 风险等级控制
|
||||
- LOW: 查询操作(SELECT)
|
||||
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
||||
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
||||
|
||||
### 环境变量加载顺序
|
||||
项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。
|
||||
|
||||
### 安全检查机制
|
||||
- 强制要求UPDATE/DELETE操作包含WHERE子句
|
||||
- SQL语句语法检查
|
||||
- 危险操作模式检测
|
||||
- 自动识别受影响的表
|
||||
- 风险等级评估
|
||||
|
||||
### 环境特性
|
||||
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
|
||||
- 生产环境:默认只允许LOW风险操作,严格限制数据修改
|
||||
|
||||
### 安全拦截
|
||||
- SQL语句长度限制
|
||||
- 危险操作模式检测
|
||||
- 自动识别受影响的表
|
||||
- SQL注入防护
|
||||
|
||||
### 事务管理
|
||||
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
||||
- 执行错误时自动回滚事务
|
||||
- 返回操作影响的行数
|
||||
|
||||
## 启动服务器
|
||||
|
||||
运行以下命令启动服务器:
|
||||
@ -61,11 +108,34 @@ python src/server.py
|
||||
├── src/ # 源代码目录
|
||||
│ ├── server.py # 主服务器文件
|
||||
│ ├── db/ # 数据库相关代码
|
||||
│ ├── security/ # SQL安全相关代码
|
||||
│ │ ├── interceptor.py # SQL拦截器
|
||||
│ │ ├── query_limiter.py # SQL安全检查器
|
||||
│ │ └── sql_analyzer.py # SQL分析器
|
||||
│ └── tools/ # 工具类代码
|
||||
├── tests/ # 测试代码目录
|
||||
├── .env.example # 环境变量示例文件
|
||||
└── requirements.txt # 项目依赖文件
|
||||
```
|
||||
|
||||
## 常见问题解决
|
||||
|
||||
### DELETE操作未执行成功
|
||||
- 检查DELETE操作是否包含WHERE条件
|
||||
- 无WHERE条件的DELETE操作被标记为CRITICAL风险级别
|
||||
- 确保环境变量ALLOWED_RISK_LEVELS中包含CRITICAL(如果需要执行该操作)
|
||||
- 检查影响行数返回值,确认操作是否实际影响了数据库
|
||||
|
||||
### 环境变量未生效
|
||||
- 确保在server.py中的load_dotenv()调用发生在导入其他模块之前
|
||||
- 重启应用以确保环境变量被正确加载
|
||||
- 检查日志中"从环境变量读取到的风险等级设置"的输出
|
||||
|
||||
### 操作被安全机制拒绝
|
||||
- 检查操作的风险级别是否在允许的范围内
|
||||
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
||||
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
||||
|
||||
## 日志系统
|
||||
|
||||
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
||||
@ -75,7 +145,9 @@ python src/server.py
|
||||
服务器包含完善的错误处理机制:
|
||||
- MySQL连接器导入检查
|
||||
- 数据库配置验证
|
||||
- SQL安全检查
|
||||
- 运行时错误捕获和记录
|
||||
- 事务自动回滚
|
||||
|
||||
## 贡献指南
|
||||
|
||||
|
||||
3
src/__init__.py
Normal file
3
src/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
MySQL查询服务器包
|
||||
"""
|
||||
3
src/db/__init__.py
Normal file
3
src/db/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
MySQL数据库操作包
|
||||
"""
|
||||
@ -5,8 +5,17 @@ from mysql.connector import Error
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..security.sql_analyzer import SQLOperationType
|
||||
from ..security.query_limiter import QueryLimiter
|
||||
from ..security.interceptor import SQLInterceptor, SecurityException
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 初始化安全组件
|
||||
sql_analyzer = SQLOperationType()
|
||||
query_limiter = QueryLimiter()
|
||||
sql_interceptor = SQLInterceptor(sql_analyzer)
|
||||
|
||||
def get_db_config():
|
||||
"""动态获取数据库配置"""
|
||||
return {
|
||||
@ -15,7 +24,8 @@ def get_db_config():
|
||||
'password': os.getenv('MYSQL_PASSWORD', ''),
|
||||
'database': os.getenv('MYSQL_DATABASE', ''),
|
||||
'port': int(os.getenv('MYSQL_PORT', '3306')),
|
||||
'connection_timeout': 5
|
||||
'connection_timeout': 5,
|
||||
'auth_plugin': 'mysql_native_password' # 指定认证插件
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
@ -54,7 +64,7 @@ def get_db_connection():
|
||||
connection.close()
|
||||
logger.debug("数据库连接已关闭")
|
||||
|
||||
def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在给定的数据库连接上执行查询
|
||||
|
||||
@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
||||
params: 查询参数 (可选)
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
查询结果列表,如果是修改操作则返回影响的行数
|
||||
|
||||
Raises:
|
||||
SecurityException: 当操作被安全机制拒绝时
|
||||
ValueError: 当查询执行失败时
|
||||
"""
|
||||
cursor = None
|
||||
try:
|
||||
# 安全检查
|
||||
if not await sql_interceptor.check_operation(query):
|
||||
raise SecurityException("操作被安全机制拒绝")
|
||||
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 执行查询
|
||||
@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
||||
else:
|
||||
cursor.execute(query)
|
||||
|
||||
# 获取结果
|
||||
# 获取操作类型
|
||||
operation = query.strip().split()[0].upper()
|
||||
|
||||
# 对于修改操作,提交事务并返回影响的行数
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
affected_rows = cursor.rowcount
|
||||
# 提交事务,确保更改被保存
|
||||
connection.commit()
|
||||
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||
return [{'affected_rows': affected_rows}]
|
||||
|
||||
# 对于查询操作,返回结果集
|
||||
results = cursor.fetchall()
|
||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
except SecurityException as security_err:
|
||||
logger.error(f"安全检查失败: {str(security_err)}")
|
||||
raise
|
||||
except mysql.connector.Error as query_err:
|
||||
# 如果发生错误,进行回滚
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
try:
|
||||
connection.rollback()
|
||||
logger.debug("事务已回滚")
|
||||
except:
|
||||
pass
|
||||
logger.error(f"查询执行失败: {str(query_err)}")
|
||||
raise ValueError(f"查询执行失败: {str(query_err)}")
|
||||
finally:
|
||||
|
||||
84
src/security/interceptor.py
Normal file
84
src/security/interceptor.py
Normal file
@ -0,0 +1,84 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
from .sql_analyzer import SQLOperationType, SQLRiskLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SecurityException(Exception):
|
||||
"""安全相关异常"""
|
||||
pass
|
||||
|
||||
class SQLInterceptor:
|
||||
"""SQL操作拦截器"""
|
||||
|
||||
def __init__(self, analyzer: SQLOperationType):
|
||||
self.analyzer = analyzer
|
||||
# 设置最大SQL长度限制(默认1000个字符)
|
||||
self.max_sql_length = 1000
|
||||
|
||||
async def check_operation(self, sql_query: str) -> bool:
|
||||
"""
|
||||
检查SQL操作是否允许执行
|
||||
|
||||
Args:
|
||||
sql_query: SQL查询语句
|
||||
|
||||
Returns:
|
||||
bool: 是否允许执行
|
||||
|
||||
Raises:
|
||||
SecurityException: 当操作被拒绝时抛出
|
||||
"""
|
||||
try:
|
||||
# 检查SQL是否为空
|
||||
if not sql_query or not sql_query.strip():
|
||||
raise SecurityException("SQL语句不能为空")
|
||||
|
||||
# 检查SQL长度
|
||||
if len(sql_query) > self.max_sql_length:
|
||||
raise SecurityException(f"SQL语句长度({len(sql_query)})超出限制({self.max_sql_length})")
|
||||
|
||||
# 检查SQL是否有效
|
||||
sql_parts = sql_query.strip().split()
|
||||
if not sql_parts:
|
||||
raise SecurityException("SQL语句格式无效")
|
||||
|
||||
operation = sql_parts[0].upper()
|
||||
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
|
||||
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||
|
||||
# 分析SQL风险
|
||||
risk_analysis = self.analyzer.analyze_risk(sql_query)
|
||||
|
||||
# 检查是否是危险操作
|
||||
if risk_analysis['is_dangerous']:
|
||||
raise SecurityException(
|
||||
f"检测到危险操作: {risk_analysis['operation']}"
|
||||
)
|
||||
|
||||
# 检查操作是否被允许
|
||||
if not risk_analysis['is_allowed']:
|
||||
raise SecurityException(
|
||||
f"当前操作风险等级({risk_analysis['risk_level'].name})不被允许执行,"
|
||||
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
|
||||
)
|
||||
|
||||
# 记录详细日志
|
||||
logger.info(
|
||||
f"SQL操作检查通过 - "
|
||||
f"操作: {risk_analysis['operation']}, "
|
||||
f"风险等级: {risk_analysis['risk_level'].name}, "
|
||||
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except SecurityException as e:
|
||||
logger.error(str(e))
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"安全检查失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise SecurityException(error_msg)
|
||||
68
src/security/query_limiter.py
Normal file
68
src/security/query_limiter.py
Normal file
@ -0,0 +1,68 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class QueryLimiter:
|
||||
"""查询安全检查器"""
|
||||
|
||||
def __init__(self):
|
||||
# 解析启用状态(默认启用)
|
||||
enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true')
|
||||
self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'}
|
||||
|
||||
def check_query(self, sql_query: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查SQL查询是否安全
|
||||
|
||||
Args:
|
||||
sql_query: SQL查询语句
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否允许执行, 错误信息)
|
||||
"""
|
||||
if not self.enable_check:
|
||||
return True, ""
|
||||
|
||||
sql_query = sql_query.strip().upper()
|
||||
operation_type = self._get_operation_type(sql_query)
|
||||
|
||||
# 检查是否为无 WHERE 子句的更新/删除操作
|
||||
if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query:
|
||||
error_msg = f"{operation_type}操作必须包含WHERE子句"
|
||||
logger.warning(f"查询被限制: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
def _get_operation_type(self, sql_query: str) -> str:
|
||||
"""获取SQL操作类型"""
|
||||
if not sql_query:
|
||||
return ""
|
||||
words = sql_query.split()
|
||||
if not words:
|
||||
return ""
|
||||
return words[0].upper()
|
||||
|
||||
def _parse_int_env(self, env_name: str, default: int) -> int:
|
||||
"""解析整数类型的环境变量"""
|
||||
try:
|
||||
return int(os.getenv(env_name, str(default)))
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
def update_limits(self, new_limits: dict):
|
||||
"""
|
||||
更新限制阈值
|
||||
|
||||
Args:
|
||||
new_limits: 新的限制值字典
|
||||
"""
|
||||
for operation, limit in new_limits.items():
|
||||
if operation in self.max_limits:
|
||||
try:
|
||||
self.max_limits[operation] = int(limit)
|
||||
logger.info(f"更新{operation}操作的限制为: {limit}")
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的限制值: {operation}={limit}")
|
||||
221
src/security/sql_analyzer.py
Normal file
221
src/security/sql_analyzer.py
Normal file
@ -0,0 +1,221 @@
|
||||
import re
|
||||
import os
|
||||
from enum import IntEnum, Enum
|
||||
import logging
|
||||
from typing import Set, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLRiskLevel(IntEnum):
|
||||
"""SQL操作风险等级"""
|
||||
LOW = 1 # 查询操作(SELECT)
|
||||
MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改
|
||||
CRITICAL = 4 # 危险操作(DROP/TRUNCATE等)
|
||||
|
||||
class EnvironmentType(Enum):
|
||||
"""环境类型"""
|
||||
DEVELOPMENT = 'development'
|
||||
PRODUCTION = 'production'
|
||||
|
||||
class SQLOperationType:
|
||||
"""SQL操作类型分析器"""
|
||||
|
||||
def __init__(self):
|
||||
# 环境类型处理
|
||||
env_type_str = os.getenv('ENV_TYPE', 'development').lower()
|
||||
try:
|
||||
self.env_type = EnvironmentType(env_type_str)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development")
|
||||
self.env_type = EnvironmentType.DEVELOPMENT
|
||||
|
||||
# 基础操作集合
|
||||
self.ddl_operations = {
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
|
||||
}
|
||||
self.dml_operations = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||
}
|
||||
|
||||
# 风险等级配置
|
||||
self.allowed_risk_levels = self._parse_risk_levels()
|
||||
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||
|
||||
# 生产环境特殊处理:如果没有明确配置风险等级,则只允许LOW风险操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
|
||||
self.allowed_risk_levels = {SQLRiskLevel.LOW}
|
||||
|
||||
logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}")
|
||||
logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}")
|
||||
|
||||
def _parse_risk_levels(self) -> Set[SQLRiskLevel]:
|
||||
"""解析允许的风险等级"""
|
||||
allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
|
||||
allowed_levels = set()
|
||||
|
||||
logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'")
|
||||
|
||||
for level_str in allowed_levels_str.upper().split(','):
|
||||
level_str = level_str.strip()
|
||||
try:
|
||||
allowed_levels.add(SQLRiskLevel[level_str])
|
||||
except KeyError:
|
||||
logger.warning(f"未知的风险等级配置: {level_str}")
|
||||
|
||||
return allowed_levels
|
||||
|
||||
def _parse_blocked_patterns(self, env_var: str) -> List[str]:
|
||||
"""解析禁止的操作模式"""
|
||||
patterns = os.getenv(env_var, '').split(',')
|
||||
return [p.strip() for p in patterns if p.strip()]
|
||||
|
||||
def analyze_risk(self, sql_query: str) -> dict:
|
||||
"""
|
||||
分析SQL查询的风险级别和影响范围
|
||||
|
||||
Args:
|
||||
sql_query: SQL查询语句
|
||||
|
||||
Returns:
|
||||
dict: 包含风险分析结果的字典
|
||||
"""
|
||||
sql_query = sql_query.strip()
|
||||
|
||||
# 处理空SQL
|
||||
if not sql_query:
|
||||
return {
|
||||
'operation': '',
|
||||
'operation_type': 'UNKNOWN',
|
||||
'is_dangerous': True,
|
||||
'affected_tables': [],
|
||||
'estimated_impact': {
|
||||
'operation': '',
|
||||
'estimated_rows': 0,
|
||||
'needs_where': False,
|
||||
'has_where': False
|
||||
},
|
||||
'risk_level': SQLRiskLevel.HIGH,
|
||||
'is_allowed': False
|
||||
}
|
||||
|
||||
operation = sql_query.split()[0].upper()
|
||||
|
||||
# 基本风险分析
|
||||
risk_analysis = {
|
||||
'operation': operation,
|
||||
'operation_type': 'DDL' if operation in self.ddl_operations else 'DML',
|
||||
'is_dangerous': self._check_dangerous_patterns(sql_query),
|
||||
'affected_tables': self._get_affected_tables(sql_query),
|
||||
'estimated_impact': self._estimate_impact(sql_query)
|
||||
}
|
||||
|
||||
# 计算风险等级
|
||||
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'])
|
||||
risk_analysis['risk_level'] = risk_level
|
||||
risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels
|
||||
|
||||
return risk_analysis
|
||||
|
||||
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel:
|
||||
"""
|
||||
计算操作风险等级
|
||||
|
||||
规则:
|
||||
1. 危险操作(匹配危险模式)=> CRITICAL
|
||||
2. DDL操作:
|
||||
- CREATE/ALTER => HIGH
|
||||
- DROP/TRUNCATE => CRITICAL
|
||||
3. DML操作:
|
||||
- SELECT => LOW
|
||||
- INSERT => MEDIUM
|
||||
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||
- UPDATE(无WHERE)=> HIGH
|
||||
- DELETE(无WHERE)=> CRITICAL
|
||||
"""
|
||||
# 危险操作
|
||||
if is_dangerous:
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 生产环境中非SELECT操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# DDL操作
|
||||
if operation in self.ddl_operations:
|
||||
if operation in {'DROP', 'TRUNCATE'}:
|
||||
return SQLRiskLevel.CRITICAL
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
# DML操作
|
||||
if operation == 'SELECT':
|
||||
return SQLRiskLevel.LOW
|
||||
elif operation == 'INSERT':
|
||||
return SQLRiskLevel.MEDIUM
|
||||
elif operation == 'UPDATE':
|
||||
return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
elif operation == 'DELETE':
|
||||
# 无WHERE条件的DELETE操作视为CRITICAL风险
|
||||
return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
def _check_dangerous_patterns(self, sql_query: str) -> bool:
|
||||
"""检查是否匹配危险操作模式"""
|
||||
sql_upper = sql_query.upper()
|
||||
|
||||
# 生产环境额外的安全检查
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中禁止所有非SELECT操作
|
||||
if sql_upper.split()[0] != 'SELECT':
|
||||
return True
|
||||
|
||||
for pattern in self.blocked_patterns:
|
||||
if re.search(pattern, sql_upper, re.IGNORECASE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_affected_tables(self, sql_query: str) -> list:
|
||||
"""获取受影响的表名列表"""
|
||||
words = sql_query.upper().split()
|
||||
tables = []
|
||||
|
||||
for i, word in enumerate(words):
|
||||
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||
if i + 1 < len(words):
|
||||
table = words[i + 1].strip('`;')
|
||||
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||
tables.append(table)
|
||||
|
||||
return list(set(tables))
|
||||
|
||||
def _estimate_impact(self, sql_query: str) -> dict:
|
||||
"""
|
||||
估算查询影响范围
|
||||
|
||||
Returns:
|
||||
dict: 包含预估影响的字典
|
||||
"""
|
||||
operation = sql_query.split()[0].upper()
|
||||
|
||||
impact = {
|
||||
'operation': operation,
|
||||
'estimated_rows': 0,
|
||||
'needs_where': operation in {'UPDATE', 'DELETE'},
|
||||
'has_where': 'WHERE' in sql_query.upper()
|
||||
}
|
||||
|
||||
# 根据环境类型调整估算
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
if operation == 'SELECT':
|
||||
impact['estimated_rows'] = 100
|
||||
else:
|
||||
impact['estimated_rows'] = float('inf') # 生产环境中非SELECT操作视为影响无限行
|
||||
else:
|
||||
if operation == 'SELECT':
|
||||
impact['estimated_rows'] = 100
|
||||
elif operation in {'UPDATE', 'DELETE'}:
|
||||
impact['estimated_rows'] = 1000 if impact['has_where'] else float('inf')
|
||||
|
||||
return impact
|
||||
@ -2,6 +2,11 @@ from mcp.server.fastmcp import FastMCP
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载
|
||||
load_dotenv()
|
||||
|
||||
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||
from src.tools.mysql_tool import register_mysql_tool
|
||||
|
||||
# 配置日志
|
||||
@ -11,6 +16,10 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 记录环境变量加载情况
|
||||
logger.debug("已加载环境变量")
|
||||
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||
|
||||
# 尝试导入MySQL连接器
|
||||
try:
|
||||
import mysql.connector
|
||||
@ -21,10 +30,6 @@ except ImportError as e:
|
||||
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
|
||||
mysql_available = False
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
logger.debug("已加载环境变量")
|
||||
|
||||
# 从环境变量获取服务器配置
|
||||
host = os.getenv('HOST', '127.0.0.1')
|
||||
port = int(os.getenv('PORT', '3000'))
|
||||
|
||||
3
src/tools/__init__.py
Normal file
3
src/tools/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
MySQL工具包
|
||||
"""
|
||||
@ -24,7 +24,7 @@ def register_mysql_tool(mcp: FastMCP):
|
||||
logger.debug("注册MySQL查询工具...")
|
||||
|
||||
@mcp.tool()
|
||||
def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
执行MySQL查询并返回结果
|
||||
|
||||
@ -39,7 +39,14 @@ def register_mysql_tool(mcp: FastMCP):
|
||||
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
results = execute_query(connection, query, params)
|
||||
results = await execute_query(connection, query, params)
|
||||
|
||||
# 检查是否是修改操作返回的影响行数
|
||||
operation = query.strip().split()[0].upper()
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
|
||||
affected_rows = results[0]['affected_rows']
|
||||
logger.info(f"{operation}操作影响了{affected_rows}行数据")
|
||||
|
||||
return json.dumps(results, default=str)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user