mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 17:52:28 +08:00
<feat> 安全限制
This commit is contained in:
28
.env.example
28
.env.example
@ -8,3 +8,31 @@ MYSQL_DATABASE=test
|
|||||||
# 服务器配置
|
# 服务器配置
|
||||||
PORT=3000
|
PORT=3000
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
|
|
||||||
|
# 环境配置
|
||||||
|
ENV_TYPE=development # 环境类型:development/production
|
||||||
|
|
||||||
|
# SQL风险控制配置
|
||||||
|
ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
|
# 如需执行无WHERE条件的DELETE操作,需要将CRITICAL添加到ALLOWED_RISK_LEVELS中
|
||||||
|
# 例如:ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH,CRITICAL
|
||||||
|
|
||||||
|
# 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||||
|
BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE
|
||||||
|
|
||||||
|
# SQL安全检查配置
|
||||||
|
ENABLE_QUERY_CHECK=true # 是否启用SQL安全检查
|
||||||
|
|
||||||
|
# 配置示例说明
|
||||||
|
# -------------------
|
||||||
|
# 开发环境配置示例:
|
||||||
|
# ENV_TYPE=development
|
||||||
|
# ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许除CRITICAL外的所有风险等级
|
||||||
|
# BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE # 仅禁止最危险的操作
|
||||||
|
# ENABLE_QUERY_CHECK=true # 启用SQL安全检查
|
||||||
|
|
||||||
|
# 生产环境配置示例:
|
||||||
|
# ENV_TYPE=production
|
||||||
|
# ALLOWED_RISK_LEVELS=LOW # 只允许低风险操作(SELECT)
|
||||||
|
# BLOCKED_PATTERNS=DROP,TRUNCATE,DELETE,UPDATE # 禁止所有数据修改操作
|
||||||
|
# ENABLE_QUERY_CHECK=true # 强制启用SQL安全检查
|
||||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -9,3 +9,9 @@ vendor/
|
|||||||
.cursorrules
|
.cursorrules
|
||||||
.cursor
|
.cursor
|
||||||
**/__pycache__/
|
**/__pycache__/
|
||||||
|
tests/
|
||||||
|
.env
|
||||||
|
.coverage
|
||||||
|
docs/
|
||||||
|
.pytest_cache/
|
||||||
|
requirements-dev.txt
|
||||||
72
README.md
72
README.md
@ -8,7 +8,14 @@
|
|||||||
- 支持SSE(Server-Sent Events)实时数据传输
|
- 支持SSE(Server-Sent Events)实时数据传输
|
||||||
- 提供MySQL数据库查询接口
|
- 提供MySQL数据库查询接口
|
||||||
- 完整的日志记录系统
|
- 完整的日志记录系统
|
||||||
|
- 自动事务管理(提交/回滚)
|
||||||
- 环境变量配置支持
|
- 环境变量配置支持
|
||||||
|
- SQL安全检查机制
|
||||||
|
- 风险等级控制
|
||||||
|
- SQL注入防护
|
||||||
|
- 危险操作拦截
|
||||||
|
- WHERE子句强制检查
|
||||||
|
- 自动返回修改操作影响的行数
|
||||||
|
|
||||||
## 系统要求
|
## 系统要求
|
||||||
|
|
||||||
@ -36,6 +43,7 @@ pip install -r requirements.txt
|
|||||||
|
|
||||||
在`.env`文件中配置以下参数:
|
在`.env`文件中配置以下参数:
|
||||||
|
|
||||||
|
### 基本配置
|
||||||
- `HOST`: 服务器监听地址(默认:127.0.0.1)
|
- `HOST`: 服务器监听地址(默认:127.0.0.1)
|
||||||
- `PORT`: 服务器监听端口(默认:3000)
|
- `PORT`: 服务器监听端口(默认:3000)
|
||||||
- `MYSQL_HOST`: MySQL服务器地址
|
- `MYSQL_HOST`: MySQL服务器地址
|
||||||
@ -44,6 +52,45 @@ pip install -r requirements.txt
|
|||||||
- `MYSQL_PASSWORD`: MySQL密码
|
- `MYSQL_PASSWORD`: MySQL密码
|
||||||
- `MYSQL_DATABASE`: MySQL数据库名
|
- `MYSQL_DATABASE`: MySQL数据库名
|
||||||
|
|
||||||
|
### SQL安全配置
|
||||||
|
- `ENV_TYPE`: 环境类型(development/production)
|
||||||
|
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
|
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||||
|
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
||||||
|
|
||||||
|
## SQL安全机制
|
||||||
|
|
||||||
|
### 风险等级控制
|
||||||
|
- LOW: 查询操作(SELECT)
|
||||||
|
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||||
|
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
||||||
|
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
||||||
|
|
||||||
|
### 环境变量加载顺序
|
||||||
|
项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。
|
||||||
|
|
||||||
|
### 安全检查机制
|
||||||
|
- 强制要求UPDATE/DELETE操作包含WHERE子句
|
||||||
|
- SQL语句语法检查
|
||||||
|
- 危险操作模式检测
|
||||||
|
- 自动识别受影响的表
|
||||||
|
- 风险等级评估
|
||||||
|
|
||||||
|
### 环境特性
|
||||||
|
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
|
||||||
|
- 生产环境:默认只允许LOW风险操作,严格限制数据修改
|
||||||
|
|
||||||
|
### 安全拦截
|
||||||
|
- SQL语句长度限制
|
||||||
|
- 危险操作模式检测
|
||||||
|
- 自动识别受影响的表
|
||||||
|
- SQL注入防护
|
||||||
|
|
||||||
|
### 事务管理
|
||||||
|
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
||||||
|
- 执行错误时自动回滚事务
|
||||||
|
- 返回操作影响的行数
|
||||||
|
|
||||||
## 启动服务器
|
## 启动服务器
|
||||||
|
|
||||||
运行以下命令启动服务器:
|
运行以下命令启动服务器:
|
||||||
@ -61,11 +108,34 @@ python src/server.py
|
|||||||
├── src/ # 源代码目录
|
├── src/ # 源代码目录
|
||||||
│ ├── server.py # 主服务器文件
|
│ ├── server.py # 主服务器文件
|
||||||
│ ├── db/ # 数据库相关代码
|
│ ├── db/ # 数据库相关代码
|
||||||
|
│ ├── security/ # SQL安全相关代码
|
||||||
|
│ │ ├── interceptor.py # SQL拦截器
|
||||||
|
│ │ ├── query_limiter.py # SQL安全检查器
|
||||||
|
│ │ └── sql_analyzer.py # SQL分析器
|
||||||
│ └── tools/ # 工具类代码
|
│ └── tools/ # 工具类代码
|
||||||
|
├── tests/ # 测试代码目录
|
||||||
├── .env.example # 环境变量示例文件
|
├── .env.example # 环境变量示例文件
|
||||||
└── requirements.txt # 项目依赖文件
|
└── requirements.txt # 项目依赖文件
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 常见问题解决
|
||||||
|
|
||||||
|
### DELETE操作未执行成功
|
||||||
|
- 检查DELETE操作是否包含WHERE条件
|
||||||
|
- 无WHERE条件的DELETE操作被标记为CRITICAL风险级别
|
||||||
|
- 确保环境变量ALLOWED_RISK_LEVELS中包含CRITICAL(如果需要执行该操作)
|
||||||
|
- 检查影响行数返回值,确认操作是否实际影响了数据库
|
||||||
|
|
||||||
|
### 环境变量未生效
|
||||||
|
- 确保在server.py中的load_dotenv()调用发生在导入其他模块之前
|
||||||
|
- 重启应用以确保环境变量被正确加载
|
||||||
|
- 检查日志中"从环境变量读取到的风险等级设置"的输出
|
||||||
|
|
||||||
|
### 操作被安全机制拒绝
|
||||||
|
- 检查操作的风险级别是否在允许的范围内
|
||||||
|
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
||||||
|
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
||||||
|
|
||||||
## 日志系统
|
## 日志系统
|
||||||
|
|
||||||
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
||||||
@ -75,7 +145,9 @@ python src/server.py
|
|||||||
服务器包含完善的错误处理机制:
|
服务器包含完善的错误处理机制:
|
||||||
- MySQL连接器导入检查
|
- MySQL连接器导入检查
|
||||||
- 数据库配置验证
|
- 数据库配置验证
|
||||||
|
- SQL安全检查
|
||||||
- 运行时错误捕获和记录
|
- 运行时错误捕获和记录
|
||||||
|
- 事务自动回滚
|
||||||
|
|
||||||
## 贡献指南
|
## 贡献指南
|
||||||
|
|
||||||
|
|||||||
3
src/__init__.py
Normal file
3
src/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
MySQL查询服务器包
|
||||||
|
"""
|
||||||
3
src/db/__init__.py
Normal file
3
src/db/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
MySQL数据库操作包
|
||||||
|
"""
|
||||||
@ -5,8 +5,17 @@ from mysql.connector import Error
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ..security.sql_analyzer import SQLOperationType
|
||||||
|
from ..security.query_limiter import QueryLimiter
|
||||||
|
from ..security.interceptor import SQLInterceptor, SecurityException
|
||||||
|
|
||||||
logger = logging.getLogger("mysql_server")
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
# 初始化安全组件
|
||||||
|
sql_analyzer = SQLOperationType()
|
||||||
|
query_limiter = QueryLimiter()
|
||||||
|
sql_interceptor = SQLInterceptor(sql_analyzer)
|
||||||
|
|
||||||
def get_db_config():
|
def get_db_config():
|
||||||
"""动态获取数据库配置"""
|
"""动态获取数据库配置"""
|
||||||
return {
|
return {
|
||||||
@ -15,7 +24,8 @@ def get_db_config():
|
|||||||
'password': os.getenv('MYSQL_PASSWORD', ''),
|
'password': os.getenv('MYSQL_PASSWORD', ''),
|
||||||
'database': os.getenv('MYSQL_DATABASE', ''),
|
'database': os.getenv('MYSQL_DATABASE', ''),
|
||||||
'port': int(os.getenv('MYSQL_PORT', '3306')),
|
'port': int(os.getenv('MYSQL_PORT', '3306')),
|
||||||
'connection_timeout': 5
|
'connection_timeout': 5,
|
||||||
|
'auth_plugin': 'mysql_native_password' # 指定认证插件
|
||||||
}
|
}
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -54,7 +64,7 @@ def get_db_connection():
|
|||||||
connection.close()
|
connection.close()
|
||||||
logger.debug("数据库连接已关闭")
|
logger.debug("数据库连接已关闭")
|
||||||
|
|
||||||
def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
在给定的数据库连接上执行查询
|
在给定的数据库连接上执行查询
|
||||||
|
|
||||||
@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
|||||||
params: 查询参数 (可选)
|
params: 查询参数 (可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
查询结果列表
|
查询结果列表,如果是修改操作则返回影响的行数
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecurityException: 当操作被安全机制拒绝时
|
||||||
|
ValueError: 当查询执行失败时
|
||||||
"""
|
"""
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
|
# 安全检查
|
||||||
|
if not await sql_interceptor.check_operation(query):
|
||||||
|
raise SecurityException("操作被安全机制拒绝")
|
||||||
|
|
||||||
cursor = connection.cursor(dictionary=True)
|
cursor = connection.cursor(dictionary=True)
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
|||||||
else:
|
else:
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
|
||||||
# 获取结果
|
# 获取操作类型
|
||||||
|
operation = query.strip().split()[0].upper()
|
||||||
|
|
||||||
|
# 对于修改操作,提交事务并返回影响的行数
|
||||||
|
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||||
|
affected_rows = cursor.rowcount
|
||||||
|
# 提交事务,确保更改被保存
|
||||||
|
connection.commit()
|
||||||
|
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||||
|
return [{'affected_rows': affected_rows}]
|
||||||
|
|
||||||
|
# 对于查询操作,返回结果集
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
except SecurityException as security_err:
|
||||||
|
logger.error(f"安全检查失败: {str(security_err)}")
|
||||||
|
raise
|
||||||
except mysql.connector.Error as query_err:
|
except mysql.connector.Error as query_err:
|
||||||
|
# 如果发生错误,进行回滚
|
||||||
|
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||||
|
try:
|
||||||
|
connection.rollback()
|
||||||
|
logger.debug("事务已回滚")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
logger.error(f"查询执行失败: {str(query_err)}")
|
logger.error(f"查询执行失败: {str(query_err)}")
|
||||||
raise ValueError(f"查询执行失败: {str(query_err)}")
|
raise ValueError(f"查询执行失败: {str(query_err)}")
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
84
src/security/interceptor.py
Normal file
84
src/security/interceptor.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from .sql_analyzer import SQLOperationType, SQLRiskLevel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SecurityException(Exception):
|
||||||
|
"""安全相关异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class SQLInterceptor:
|
||||||
|
"""SQL操作拦截器"""
|
||||||
|
|
||||||
|
def __init__(self, analyzer: SQLOperationType):
|
||||||
|
self.analyzer = analyzer
|
||||||
|
# 设置最大SQL长度限制(默认1000个字符)
|
||||||
|
self.max_sql_length = 1000
|
||||||
|
|
||||||
|
async def check_operation(self, sql_query: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查SQL操作是否允许执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql_query: SQL查询语句
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否允许执行
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecurityException: 当操作被拒绝时抛出
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查SQL是否为空
|
||||||
|
if not sql_query or not sql_query.strip():
|
||||||
|
raise SecurityException("SQL语句不能为空")
|
||||||
|
|
||||||
|
# 检查SQL长度
|
||||||
|
if len(sql_query) > self.max_sql_length:
|
||||||
|
raise SecurityException(f"SQL语句长度({len(sql_query)})超出限制({self.max_sql_length})")
|
||||||
|
|
||||||
|
# 检查SQL是否有效
|
||||||
|
sql_parts = sql_query.strip().split()
|
||||||
|
if not sql_parts:
|
||||||
|
raise SecurityException("SQL语句格式无效")
|
||||||
|
|
||||||
|
operation = sql_parts[0].upper()
|
||||||
|
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
|
||||||
|
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||||
|
|
||||||
|
# 分析SQL风险
|
||||||
|
risk_analysis = self.analyzer.analyze_risk(sql_query)
|
||||||
|
|
||||||
|
# 检查是否是危险操作
|
||||||
|
if risk_analysis['is_dangerous']:
|
||||||
|
raise SecurityException(
|
||||||
|
f"检测到危险操作: {risk_analysis['operation']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查操作是否被允许
|
||||||
|
if not risk_analysis['is_allowed']:
|
||||||
|
raise SecurityException(
|
||||||
|
f"当前操作风险等级({risk_analysis['risk_level'].name})不被允许执行,"
|
||||||
|
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录详细日志
|
||||||
|
logger.info(
|
||||||
|
f"SQL操作检查通过 - "
|
||||||
|
f"操作: {risk_analysis['operation']}, "
|
||||||
|
f"风险等级: {risk_analysis['risk_level'].name}, "
|
||||||
|
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SecurityException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"安全检查失败: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise SecurityException(error_msg)
|
||||||
68
src/security/query_limiter.py
Normal file
68
src/security/query_limiter.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class QueryLimiter:
|
||||||
|
"""查询安全检查器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 解析启用状态(默认启用)
|
||||||
|
enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true')
|
||||||
|
self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'}
|
||||||
|
|
||||||
|
def check_query(self, sql_query: str) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
检查SQL查询是否安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql_query: SQL查询语句
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (是否允许执行, 错误信息)
|
||||||
|
"""
|
||||||
|
if not self.enable_check:
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
sql_query = sql_query.strip().upper()
|
||||||
|
operation_type = self._get_operation_type(sql_query)
|
||||||
|
|
||||||
|
# 检查是否为无 WHERE 子句的更新/删除操作
|
||||||
|
if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query:
|
||||||
|
error_msg = f"{operation_type}操作必须包含WHERE子句"
|
||||||
|
logger.warning(f"查询被限制: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def _get_operation_type(self, sql_query: str) -> str:
|
||||||
|
"""获取SQL操作类型"""
|
||||||
|
if not sql_query:
|
||||||
|
return ""
|
||||||
|
words = sql_query.split()
|
||||||
|
if not words:
|
||||||
|
return ""
|
||||||
|
return words[0].upper()
|
||||||
|
|
||||||
|
def _parse_int_env(self, env_name: str, default: int) -> int:
|
||||||
|
"""解析整数类型的环境变量"""
|
||||||
|
try:
|
||||||
|
return int(os.getenv(env_name, str(default)))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
def update_limits(self, new_limits: dict):
|
||||||
|
"""
|
||||||
|
更新限制阈值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_limits: 新的限制值字典
|
||||||
|
"""
|
||||||
|
for operation, limit in new_limits.items():
|
||||||
|
if operation in self.max_limits:
|
||||||
|
try:
|
||||||
|
self.max_limits[operation] = int(limit)
|
||||||
|
logger.info(f"更新{operation}操作的限制为: {limit}")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(f"无效的限制值: {operation}={limit}")
|
||||||
221
src/security/sql_analyzer.py
Normal file
221
src/security/sql_analyzer.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
from enum import IntEnum, Enum
|
||||||
|
import logging
|
||||||
|
from typing import Set, List
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SQLRiskLevel(IntEnum):
|
||||||
|
"""SQL操作风险等级"""
|
||||||
|
LOW = 1 # 查询操作(SELECT)
|
||||||
|
MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||||
|
HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改
|
||||||
|
CRITICAL = 4 # 危险操作(DROP/TRUNCATE等)
|
||||||
|
|
||||||
|
class EnvironmentType(Enum):
|
||||||
|
"""环境类型"""
|
||||||
|
DEVELOPMENT = 'development'
|
||||||
|
PRODUCTION = 'production'
|
||||||
|
|
||||||
|
class SQLOperationType:
|
||||||
|
"""SQL操作类型分析器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 环境类型处理
|
||||||
|
env_type_str = os.getenv('ENV_TYPE', 'development').lower()
|
||||||
|
try:
|
||||||
|
self.env_type = EnvironmentType(env_type_str)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development")
|
||||||
|
self.env_type = EnvironmentType.DEVELOPMENT
|
||||||
|
|
||||||
|
# 基础操作集合
|
||||||
|
self.ddl_operations = {
|
||||||
|
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
|
||||||
|
}
|
||||||
|
self.dml_operations = {
|
||||||
|
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 风险等级配置
|
||||||
|
self.allowed_risk_levels = self._parse_risk_levels()
|
||||||
|
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||||
|
|
||||||
|
# 生产环境特殊处理:如果没有明确配置风险等级,则只允许LOW风险操作
|
||||||
|
if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
|
||||||
|
self.allowed_risk_levels = {SQLRiskLevel.LOW}
|
||||||
|
|
||||||
|
logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}")
|
||||||
|
logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}")
|
||||||
|
|
||||||
|
def _parse_risk_levels(self) -> Set[SQLRiskLevel]:
|
||||||
|
"""解析允许的风险等级"""
|
||||||
|
allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
|
||||||
|
allowed_levels = set()
|
||||||
|
|
||||||
|
logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'")
|
||||||
|
|
||||||
|
for level_str in allowed_levels_str.upper().split(','):
|
||||||
|
level_str = level_str.strip()
|
||||||
|
try:
|
||||||
|
allowed_levels.add(SQLRiskLevel[level_str])
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"未知的风险等级配置: {level_str}")
|
||||||
|
|
||||||
|
return allowed_levels
|
||||||
|
|
||||||
|
def _parse_blocked_patterns(self, env_var: str) -> List[str]:
|
||||||
|
"""解析禁止的操作模式"""
|
||||||
|
patterns = os.getenv(env_var, '').split(',')
|
||||||
|
return [p.strip() for p in patterns if p.strip()]
|
||||||
|
|
||||||
|
def analyze_risk(self, sql_query: str) -> dict:
|
||||||
|
"""
|
||||||
|
分析SQL查询的风险级别和影响范围
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql_query: SQL查询语句
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含风险分析结果的字典
|
||||||
|
"""
|
||||||
|
sql_query = sql_query.strip()
|
||||||
|
|
||||||
|
# 处理空SQL
|
||||||
|
if not sql_query:
|
||||||
|
return {
|
||||||
|
'operation': '',
|
||||||
|
'operation_type': 'UNKNOWN',
|
||||||
|
'is_dangerous': True,
|
||||||
|
'affected_tables': [],
|
||||||
|
'estimated_impact': {
|
||||||
|
'operation': '',
|
||||||
|
'estimated_rows': 0,
|
||||||
|
'needs_where': False,
|
||||||
|
'has_where': False
|
||||||
|
},
|
||||||
|
'risk_level': SQLRiskLevel.HIGH,
|
||||||
|
'is_allowed': False
|
||||||
|
}
|
||||||
|
|
||||||
|
operation = sql_query.split()[0].upper()
|
||||||
|
|
||||||
|
# 基本风险分析
|
||||||
|
risk_analysis = {
|
||||||
|
'operation': operation,
|
||||||
|
'operation_type': 'DDL' if operation in self.ddl_operations else 'DML',
|
||||||
|
'is_dangerous': self._check_dangerous_patterns(sql_query),
|
||||||
|
'affected_tables': self._get_affected_tables(sql_query),
|
||||||
|
'estimated_impact': self._estimate_impact(sql_query)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算风险等级
|
||||||
|
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'])
|
||||||
|
risk_analysis['risk_level'] = risk_level
|
||||||
|
risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels
|
||||||
|
|
||||||
|
return risk_analysis
|
||||||
|
|
||||||
|
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel:
|
||||||
|
"""
|
||||||
|
计算操作风险等级
|
||||||
|
|
||||||
|
规则:
|
||||||
|
1. 危险操作(匹配危险模式)=> CRITICAL
|
||||||
|
2. DDL操作:
|
||||||
|
- CREATE/ALTER => HIGH
|
||||||
|
- DROP/TRUNCATE => CRITICAL
|
||||||
|
3. DML操作:
|
||||||
|
- SELECT => LOW
|
||||||
|
- INSERT => MEDIUM
|
||||||
|
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||||
|
- UPDATE(无WHERE)=> HIGH
|
||||||
|
- DELETE(无WHERE)=> CRITICAL
|
||||||
|
"""
|
||||||
|
# 危险操作
|
||||||
|
if is_dangerous:
|
||||||
|
return SQLRiskLevel.CRITICAL
|
||||||
|
|
||||||
|
# 生产环境中非SELECT操作
|
||||||
|
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||||
|
return SQLRiskLevel.CRITICAL
|
||||||
|
|
||||||
|
# DDL操作
|
||||||
|
if operation in self.ddl_operations:
|
||||||
|
if operation in {'DROP', 'TRUNCATE'}:
|
||||||
|
return SQLRiskLevel.CRITICAL
|
||||||
|
return SQLRiskLevel.HIGH
|
||||||
|
|
||||||
|
# DML操作
|
||||||
|
if operation == 'SELECT':
|
||||||
|
return SQLRiskLevel.LOW
|
||||||
|
elif operation == 'INSERT':
|
||||||
|
return SQLRiskLevel.MEDIUM
|
||||||
|
elif operation == 'UPDATE':
|
||||||
|
return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||||
|
elif operation == 'DELETE':
|
||||||
|
# 无WHERE条件的DELETE操作视为CRITICAL风险
|
||||||
|
return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||||
|
|
||||||
|
return SQLRiskLevel.HIGH
|
||||||
|
|
||||||
|
def _check_dangerous_patterns(self, sql_query: str) -> bool:
|
||||||
|
"""检查是否匹配危险操作模式"""
|
||||||
|
sql_upper = sql_query.upper()
|
||||||
|
|
||||||
|
# 生产环境额外的安全检查
|
||||||
|
if self.env_type == EnvironmentType.PRODUCTION:
|
||||||
|
# 生产环境中禁止所有非SELECT操作
|
||||||
|
if sql_upper.split()[0] != 'SELECT':
|
||||||
|
return True
|
||||||
|
|
||||||
|
for pattern in self.blocked_patterns:
|
||||||
|
if re.search(pattern, sql_upper, re.IGNORECASE):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_affected_tables(self, sql_query: str) -> list:
|
||||||
|
"""获取受影响的表名列表"""
|
||||||
|
words = sql_query.upper().split()
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||||
|
if i + 1 < len(words):
|
||||||
|
table = words[i + 1].strip('`;')
|
||||||
|
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||||
|
tables.append(table)
|
||||||
|
|
||||||
|
return list(set(tables))
|
||||||
|
|
||||||
|
def _estimate_impact(self, sql_query: str) -> dict:
|
||||||
|
"""
|
||||||
|
估算查询影响范围
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含预估影响的字典
|
||||||
|
"""
|
||||||
|
operation = sql_query.split()[0].upper()
|
||||||
|
|
||||||
|
impact = {
|
||||||
|
'operation': operation,
|
||||||
|
'estimated_rows': 0,
|
||||||
|
'needs_where': operation in {'UPDATE', 'DELETE'},
|
||||||
|
'has_where': 'WHERE' in sql_query.upper()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据环境类型调整估算
|
||||||
|
if self.env_type == EnvironmentType.PRODUCTION:
|
||||||
|
if operation == 'SELECT':
|
||||||
|
impact['estimated_rows'] = 100
|
||||||
|
else:
|
||||||
|
impact['estimated_rows'] = float('inf') # 生产环境中非SELECT操作视为影响无限行
|
||||||
|
else:
|
||||||
|
if operation == 'SELECT':
|
||||||
|
impact['estimated_rows'] = 100
|
||||||
|
elif operation in {'UPDATE', 'DELETE'}:
|
||||||
|
impact['estimated_rows'] = 1000 if impact['has_where'] else float('inf')
|
||||||
|
|
||||||
|
return impact
|
||||||
@ -2,6 +2,11 @@ from mcp.server.fastmcp import FastMCP
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||||
from src.tools.mysql_tool import register_mysql_tool
|
from src.tools.mysql_tool import register_mysql_tool
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
@ -11,6 +16,10 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger("mysql_server")
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
# 记录环境变量加载情况
|
||||||
|
logger.debug("已加载环境变量")
|
||||||
|
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||||
|
|
||||||
# 尝试导入MySQL连接器
|
# 尝试导入MySQL连接器
|
||||||
try:
|
try:
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
@ -21,10 +30,6 @@ except ImportError as e:
|
|||||||
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
|
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
|
||||||
mysql_available = False
|
mysql_available = False
|
||||||
|
|
||||||
# 加载环境变量
|
|
||||||
load_dotenv()
|
|
||||||
logger.debug("已加载环境变量")
|
|
||||||
|
|
||||||
# 从环境变量获取服务器配置
|
# 从环境变量获取服务器配置
|
||||||
host = os.getenv('HOST', '127.0.0.1')
|
host = os.getenv('HOST', '127.0.0.1')
|
||||||
port = int(os.getenv('PORT', '3000'))
|
port = int(os.getenv('PORT', '3000'))
|
||||||
|
|||||||
3
src/tools/__init__.py
Normal file
3
src/tools/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
MySQL工具包
|
||||||
|
"""
|
||||||
@ -24,7 +24,7 @@ def register_mysql_tool(mcp: FastMCP):
|
|||||||
logger.debug("注册MySQL查询工具...")
|
logger.debug("注册MySQL查询工具...")
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||||
"""
|
"""
|
||||||
执行MySQL查询并返回结果
|
执行MySQL查询并返回结果
|
||||||
|
|
||||||
@ -39,7 +39,14 @@ def register_mysql_tool(mcp: FastMCP):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as connection:
|
with get_db_connection() as connection:
|
||||||
results = execute_query(connection, query, params)
|
results = await execute_query(connection, query, params)
|
||||||
|
|
||||||
|
# 检查是否是修改操作返回的影响行数
|
||||||
|
operation = query.strip().split()[0].upper()
|
||||||
|
if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
|
||||||
|
affected_rows = results[0]['affected_rows']
|
||||||
|
logger.info(f"{operation}操作影响了{affected_rows}行数据")
|
||||||
|
|
||||||
return json.dumps(results, default=str)
|
return json.dumps(results, default=str)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user