mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 09:42:27 +08:00
<feat> 重构配置管理与安全检查机制,新增SQL解析器,优化数据库连接池管理
This commit is contained in:
38
.env.example
38
.env.example
@ -1,38 +0,0 @@
|
||||
# MySQL数据库连接配置
|
||||
MYSQL_HOST=127.0.0.1
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_USER=root
|
||||
MYSQL_PASSWORD=root
|
||||
MYSQL_DATABASE=test
|
||||
|
||||
# 服务器配置
|
||||
PORT=3000
|
||||
HOST=127.0.0.1
|
||||
|
||||
# 环境配置
|
||||
ENV_TYPE=development # 环境类型:development/production
|
||||
|
||||
# SQL风险控制配置
|
||||
ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
# 如需执行无WHERE条件的DELETE操作,需要将CRITICAL添加到ALLOWED_RISK_LEVELS中
|
||||
# 例如:ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH,CRITICAL
|
||||
|
||||
# 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||
BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE
|
||||
|
||||
# SQL安全检查配置
|
||||
ENABLE_QUERY_CHECK=true # 是否启用SQL安全检查
|
||||
|
||||
# 配置示例说明
|
||||
# -------------------
|
||||
# 开发环境配置示例:
|
||||
# ENV_TYPE=development
|
||||
# ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许除CRITICAL外的所有风险等级
|
||||
# BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE # 仅禁止最危险的操作
|
||||
# ENABLE_QUERY_CHECK=true # 启用SQL安全检查
|
||||
|
||||
# 生产环境配置示例:
|
||||
# ENV_TYPE=production
|
||||
# ALLOWED_RISK_LEVELS=LOW # 只允许低风险操作(SELECT)
|
||||
# BLOCKED_PATTERNS=DROP,TRUNCATE,DELETE,UPDATE # 禁止所有数据修改操作
|
||||
# ENABLE_QUERY_CHECK=true # 强制启用SQL安全检查
|
||||
335
README.md
335
README.md
@ -1,209 +1,206 @@
|
||||
# MySQL查询服务器
|
||||
# MySQL查询服务器 / MySQL Query Server
|
||||
|
||||
这是一个基于MCP(Model-Controller-Provider)框架的MySQL查询服务器,提供了通过SSE(Server-Sent Events)进行MySQL数据库操作的功能。
|
||||
---
|
||||
|
||||
## 功能特点
|
||||
## 1. 项目简介 / Project Introduction
|
||||
|
||||
- 基于FastMCP框架构建
|
||||
- 支持SSE(Server-Sent Events)实时数据传输
|
||||
- 提供MySQL数据库查询接口
|
||||
- 完整的日志记录系统
|
||||
- 自动事务管理(提交/回滚)
|
||||
- 环境变量配置支持
|
||||
- SQL安全检查机制
|
||||
- 风险等级控制
|
||||
- SQL注入防护
|
||||
- 危险操作拦截
|
||||
- WHERE子句强制检查
|
||||
- 自动返回修改操作影响的行数
|
||||
- 敏感信息保护机制
|
||||
- 自动对元数据查询结果进行格式化和增强
|
||||
本项目是基于MCP框架的MySQL查询服务器,支持通过SSE协议进行实时数据库操作,具备完善的安全、日志、配置和敏感信息保护机制,适用于开发、测试和生产环境下的安全MySQL数据访问。
|
||||
|
||||
## API接口功能
|
||||
This project is a MySQL query server based on the MCP framework, supporting real-time database operations via SSE protocol. It features comprehensive security, logging, configuration, and sensitive information protection mechanisms, suitable for secure MySQL data access in development, testing, and production environments.
|
||||
|
||||
系统提供以下四大类工具:
|
||||
---
|
||||
|
||||
### 基础查询工具
|
||||
## 2. 主要特性 / Key Features
|
||||
|
||||
- `mysql_query`: 执行任意SQL查询,支持参数化查询
|
||||
- 基于FastMCP框架,异步高性能
|
||||
- 支持高并发的数据库连接池,参数灵活可调
|
||||
- 支持SSE实时推送
|
||||
- 丰富的MySQL元数据与结构查询API
|
||||
- 自动事务管理与回滚
|
||||
- 多级SQL风险控制与注入防护
|
||||
- 敏感信息自动隐藏与自定义
|
||||
- 灵活的环境变量配置
|
||||
- 完善的日志与错误处理
|
||||
|
||||
### 元数据查询工具
|
||||
- Built on FastMCP framework, high-performance async
|
||||
- Connection pool for high concurrency, with flexible parameter tuning
|
||||
- SSE real-time push support
|
||||
- Rich MySQL metadata & schema query APIs
|
||||
- Automatic transaction management & rollback
|
||||
- Multi-level SQL risk control & injection protection
|
||||
- Automatic and customizable sensitive info masking
|
||||
- Flexible environment variable configuration
|
||||
- Robust logging & error handling
|
||||
|
||||
- `mysql_show_tables`: 获取数据库中的表列表,支持模式匹配和限制结果数量
|
||||
- `mysql_show_columns`: 获取表的列信息
|
||||
- `mysql_describe_table`: 描述表结构
|
||||
- `mysql_show_create_table`: 获取表的创建语句
|
||||
---
|
||||
|
||||
### 数据库信息查询工具
|
||||
## 3. 快速开始 / Quick Start
|
||||
|
||||
- `mysql_show_databases`: 获取所有数据库列表,支持过滤系统数据库
|
||||
- `mysql_show_variables`: 获取MySQL服务器变量
|
||||
- `mysql_show_status`: 获取MySQL服务器状态信息
|
||||
|
||||
### 表结构高级查询工具
|
||||
|
||||
- `mysql_show_indexes`: 获取表的索引信息
|
||||
- `mysql_show_table_status`: 获取表状态信息
|
||||
- `mysql_show_foreign_keys`: 获取表的外键约束信息
|
||||
- `mysql_paginate_results`: 提供结果分页功能
|
||||
|
||||
## 系统要求
|
||||
|
||||
- Python 3.6+
|
||||
- MySQL服务器
|
||||
- 依赖包:
|
||||
- mysql-connector-python
|
||||
- python-dotenv
|
||||
- mcp (FastMCP框架)
|
||||
|
||||
## 安装步骤
|
||||
|
||||
1. 克隆项目到本地:
|
||||
```bash
|
||||
git clone [项目地址]
|
||||
cd mysql-query-server
|
||||
```
|
||||
|
||||
2. 安装依赖包:
|
||||
### 安装依赖 / Install Dependencies
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. 配置环境变量:
|
||||
- 复制`.env.example`文件并重命名为`.env`
|
||||
- 根据实际情况修改`.env`文件中的配置
|
||||
|
||||
## 环境变量配置
|
||||
|
||||
在`.env`文件中配置以下参数:
|
||||
|
||||
### 基本配置
|
||||
- `HOST`: 服务器监听地址(默认:127.0.0.1)
|
||||
- `PORT`: 服务器监听端口(默认:3000)
|
||||
- `MYSQL_HOST`: MySQL服务器地址
|
||||
- `MYSQL_PORT`: MySQL服务器端口
|
||||
- `MYSQL_USER`: MySQL用户名
|
||||
- `MYSQL_PASSWORD`: MySQL密码
|
||||
- `MYSQL_DATABASE`: MySQL数据库名
|
||||
|
||||
### SQL安全配置
|
||||
- `ENV_TYPE`: 环境类型(development/production)
|
||||
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
||||
- `ALLOW_SENSITIVE_INFO`: 是否允许查询敏感信息(true/false)
|
||||
- `SENSITIVE_INFO_FIELDS`: 自定义敏感字段模式列表(逗号分隔)
|
||||
|
||||
## 安全机制详解
|
||||
|
||||
### 风险等级控制
|
||||
- LOW: 查询操作(SELECT)和元数据操作(SHOW, DESCRIBE等)
|
||||
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
||||
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
||||
|
||||
### 环境特性差异
|
||||
- 开发环境:
|
||||
- 允许较高风险的操作
|
||||
- 不隐藏敏感信息
|
||||
- 提供详细的错误信息
|
||||
- 生产环境:
|
||||
- 默认只允许LOW风险操作
|
||||
- 严格限制数据修改
|
||||
- 自动隐藏敏感信息
|
||||
- 错误信息不暴露实现细节
|
||||
|
||||
### 敏感信息保护
|
||||
系统会自动检测并隐藏包含以下关键词的变量/状态值:
|
||||
- password、auth、credential、key、secret、private
|
||||
- ssl、tls、cipher、certificate
|
||||
- host、path、directory等系统路径信息
|
||||
|
||||
### 事务管理
|
||||
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
||||
- 执行错误时自动回滚事务
|
||||
- 返回操作影响的行数
|
||||
|
||||
## 启动服务器
|
||||
|
||||
运行以下命令启动服务器:
|
||||
### 配置环境变量 / Configure Environment Variables
|
||||
复制`.env.example`为`.env`,并根据实际情况修改。
|
||||
Copy `.env.example` to `.env` and modify as needed.
|
||||
|
||||
### 启动服务 / Start the Server
|
||||
```bash
|
||||
python src/server.py
|
||||
python -m src.server
|
||||
```
|
||||
默认监听:http://127.0.0.1:3000/sse
|
||||
Default endpoint: http://127.0.0.1:3000/sse
|
||||
|
||||
服务器将在配置的地址和端口上启动,默认为 `http://127.0.0.1:3000/sse`
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
## 4. 目录结构 / Project Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── src/ # 源代码目录
|
||||
│ ├── server.py # 主服务器文件
|
||||
│ ├── db/ # 数据库相关代码
|
||||
│ │ └── mysql_operations.py # MySQL操作实现
|
||||
│ ├── security/ # SQL安全相关代码
|
||||
│ │ ├── interceptor.py # SQL拦截器
|
||||
│ │ ├── query_limiter.py # SQL安全检查器
|
||||
│ │ └── sql_analyzer.py # SQL分析器
|
||||
│ └── tools/ # 工具类代码
|
||||
│ ├── mysql_tool.py # 基础查询工具
|
||||
│ ├── mysql_metadata_tool.py # 元数据查询工具
|
||||
│ ├── mysql_info_tool.py # 数据库信息查询工具
|
||||
│ ├── mysql_schema_tool.py # 表结构高级查询工具
|
||||
│ └── metadata_base_tool.py # 元数据工具基类
|
||||
├── tests/ # 测试代码目录
|
||||
├── .env.example # 环境变量示例文件
|
||||
└── requirements.txt # 项目依赖文件
|
||||
├── src/
|
||||
│ ├── server.py # 主服务器入口 / Main server entry
|
||||
│ ├── config.py # 配置项定义 / Config definitions
|
||||
│ ├── validators.py # 参数校验 / Parameter validation
|
||||
│ ├── db/
|
||||
│ │ └── mysql_operations.py # 数据库操作 / DB operations
|
||||
│ ├── security/
|
||||
│ │ ├── interceptor.py # SQL拦截 / SQL interception
|
||||
│ │ ├── query_limiter.py # 风险控制 / Risk control
|
||||
│ │ └── sql_analyzer.py # SQL分析 / SQL analysis
|
||||
│ └── tools/
|
||||
│ ├── mysql_tool.py # 基础查询 / Basic query
|
||||
│ ├── mysql_metadata_tool.py # 元数据查询 / Metadata query
|
||||
│ ├── mysql_info_tool.py # 信息查询 / Info query
|
||||
│ ├── mysql_schema_tool.py # 结构查询 / Schema query
|
||||
│ └── metadata_base_tool.py # 工具基类 / Tool base class
|
||||
├── tests/ # 测试 / Tests
|
||||
├── .env.example # 环境变量示例 / Env example
|
||||
└── requirements.txt # 依赖 / Requirements
|
||||
```
|
||||
|
||||
## 常见问题解决
|
||||
---
|
||||
|
||||
### DELETE操作未执行成功
|
||||
- 检查DELETE操作是否包含WHERE条件
|
||||
- 无WHERE条件的DELETE操作被标记为CRITICAL风险级别
|
||||
- 确保环境变量ALLOWED_RISK_LEVELS中包含CRITICAL(如果需要执行该操作)
|
||||
- 检查影响行数返回值,确认操作是否实际影响了数据库
|
||||
## 5. 环境变量与配置 / Environment Variables & Configuration
|
||||
|
||||
### 环境变量未生效
|
||||
- 确保在server.py中的load_dotenv()调用发生在导入其他模块之前
|
||||
- 重启应用以确保环境变量被正确加载
|
||||
- 检查日志中"从环境变量读取到的风险等级设置"的输出
|
||||
| 变量名 / Variable | 说明 / Description | 默认值 / Default |
|
||||
|--------------------------|------------------------------------------------------|------------------|
|
||||
| HOST | 服务器监听地址 / Server listen address | 127.0.0.1 |
|
||||
| PORT | 服务器监听端口 / Server listen port | 3000 |
|
||||
| MYSQL_HOST | MySQL服务器地址 / MySQL server host | localhost |
|
||||
| MYSQL_PORT | MySQL服务器端口 / MySQL server port | 3306 |
|
||||
| MYSQL_USER | MySQL用户名 / MySQL username | root |
|
||||
| MYSQL_PASSWORD | MySQL密码 / MySQL password | (空/empty) |
|
||||
| MYSQL_DATABASE | 要连接的数据库名 / Database name | (空/empty) |
|
||||
| DB_CONNECTION_TIMEOUT | 连接超时时间(秒) / Connection timeout (seconds) | 5 |
|
||||
| DB_AUTH_PLUGIN | 认证插件类型 / Auth plugin type | mysql_native_password |
|
||||
| DB_POOL_ENABLED | 是否启用连接池 / Enable connection pool (true/false) | true |
|
||||
| DB_POOL_MIN_SIZE | 连接池最小连接数 / Pool min size | 5 |
|
||||
| DB_POOL_MAX_SIZE | 连接池最大连接数 / Pool max size | 20 |
|
||||
| DB_POOL_RECYCLE | 连接回收时间(秒) / Pool recycle time (seconds) | 300 |
|
||||
| DB_POOL_MAX_LIFETIME | 连接最大存活时间(秒, 0=不限制) / Max lifetime (sec) | 0 |
|
||||
| DB_POOL_ACQUIRE_TIMEOUT | 获取连接超时时间(秒) / Acquire timeout (seconds) | 10.0 |
|
||||
| ENV_TYPE | 环境类型(development/production) / Env type | development |
|
||||
| ALLOWED_RISK_LEVELS | 允许的风险等级(逗号分隔) / Allowed risk levels | LOW,MEDIUM |
|
||||
| ALLOW_SENSITIVE_INFO | 允许查询敏感字段 / Allow sensitive info (true/false) | false |
|
||||
| SENSITIVE_INFO_FIELDS | 自定义敏感字段模式(逗号分隔) / Custom sensitive fields | (空/empty) |
|
||||
| MAX_SQL_LENGTH | 最大SQL语句长度 / Max SQL length | 5000 |
|
||||
| BLOCKED_PATTERNS | 阻止的SQL模式(逗号分隔) / Blocked SQL patterns | (空/empty) |
|
||||
| ENABLE_QUERY_CHECK | 启用查询安全检查 / Enable query check (true/false) | true |
|
||||
| LOG_LEVEL | 日志级别(DEBUG/INFO/...) / Log level | DEBUG |
|
||||
|
||||
### 操作被安全机制拒绝
|
||||
- 检查操作的风险级别是否在允许的范围内
|
||||
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
||||
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
||||
> 注/Note: 部分云MySQL需指定`DB_AUTH_PLUGIN`为`mysql_native_password`。
|
||||
|
||||
### 无法查看敏感信息
|
||||
- 在开发环境中,设置ALLOW_SENSITIVE_INFO=true
|
||||
- 在生产环境中,敏感信息默认会被隐藏,这是安全特性
|
||||
---
|
||||
|
||||
## 日志系统
|
||||
## 6. 自动化与资源管理优化 / Automation & Resource Management Enhancements
|
||||
|
||||
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
||||
### 自动化工具注册 / Automated Tool Registration
|
||||
- 所有MySQL相关API工具均采用自动注册机制:
|
||||
- 无需手动在主入口维护注册代码,新增/删除工具只需在`src/tools/`目录下实现`register_xxx_tool(s)`函数即可。
|
||||
- 系统启动时自动扫描并注册,极大提升可维护性和扩展性。
|
||||
- All MySQL-related API tools are registered automatically:
|
||||
- No need to manually maintain registration code in the main entry. To add or remove a tool, simply implement a `register_xxx_tool(s)` function in the `src/tools/` directory.
|
||||
- The system scans and registers tools automatically at startup, greatly improving maintainability and extensibility.
|
||||
|
||||
## 错误处理
|
||||
### 连接池自动回收与资源管理 / Connection Pool Auto-Recycling & Resource Management
|
||||
- 连接池采用事件循环隔离与自动回收机制:
|
||||
- 每个事件循环独立池,支持高并发与多环境。
|
||||
- 定期(默认每5分钟)自动回收无效或失效的连接池,防止资源泄漏。
|
||||
- 事件循环关闭时自动关闭对应连接池,确保资源彻底释放。
|
||||
- 支持多数据库/多租户场景扩展。
|
||||
- 所有资源管理操作均有详细日志,便于追踪和排查。
|
||||
- The connection pool uses event loop isolation and auto-recycling:
|
||||
- Each event loop has its own pool, supporting high concurrency and multi-environment deployment.
|
||||
- Unused or invalid pools are automatically recycled every 5 minutes (by default), preventing resource leaks.
|
||||
- When an event loop is closed, its pool is automatically closed to ensure complete resource release.
|
||||
- Ready for multi-database/multi-tenant scenarios.
|
||||
- All resource management operations are logged in detail for easy tracking and troubleshooting.
|
||||
|
||||
服务器包含完善的错误处理机制:
|
||||
- MySQL连接器导入检查
|
||||
- 数据库配置验证
|
||||
- SQL安全检查
|
||||
- 运行时错误捕获和记录
|
||||
- 事务自动回滚
|
||||
---
|
||||
|
||||
## 贡献指南
|
||||
## 7. 安全机制 / Security Mechanisms
|
||||
|
||||
欢迎提交Issue和Pull Request来改进项目。
|
||||
- 多级SQL风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
- SQL注入与危险操作拦截
|
||||
- WHERE子句强制检查
|
||||
- 敏感信息自动隐藏(支持自定义字段)
|
||||
- 生产环境默认只允许低风险操作
|
||||
|
||||
## 许可证
|
||||
- Multi-level SQL risk levels (LOW/MEDIUM/HIGH/CRITICAL)
|
||||
- SQL injection & dangerous operation interception
|
||||
- Mandatory WHERE clause check
|
||||
- Automatic sensitive info masking (customizable fields)
|
||||
- Production allows only low-risk operations by default
|
||||
|
||||
---
|
||||
|
||||
## 8. 日志与错误处理 / Logging & Error Handling
|
||||
|
||||
- 日志级别可配置(LOG_LEVEL)
|
||||
- 控制台与文件日志输出
|
||||
- 详细记录运行状态与错误
|
||||
- 完善的异常捕获与事务回滚
|
||||
|
||||
- Configurable log level (LOG_LEVEL)
|
||||
- Console & file log output
|
||||
- Detailed running status & error logs
|
||||
- Robust exception capture & transaction rollback
|
||||
|
||||
---
|
||||
|
||||
## 9. 常见问题 / FAQ
|
||||
|
||||
### Q: DELETE操作未执行成功?
|
||||
A: 检查是否有WHERE条件,无WHERE为高风险,需在ALLOWED_RISK_LEVELS中允许CRITICAL。
|
||||
|
||||
Q: Why does DELETE not work?
|
||||
A: Check for WHERE clause. DELETE without WHERE is high risk (CRITICAL), must be allowed in ALLOWED_RISK_LEVELS.
|
||||
|
||||
### Q: 如何自定义敏感字段?
|
||||
A: 设置SENSITIVE_INFO_FIELDS,如SENSITIVE_INFO_FIELDS=password,token
|
||||
|
||||
Q: How to customize sensitive fields?
|
||||
A: Set SENSITIVE_INFO_FIELDS, e.g. SENSITIVE_INFO_FIELDS=password,token
|
||||
|
||||
### Q: limit参数报错?
|
||||
A: limit必须为非负整数。
|
||||
|
||||
Q: limit parameter error?
|
||||
A: limit must be a non-negative integer.
|
||||
|
||||
---
|
||||
|
||||
## 10. 贡献指南 / Contribution Guide
|
||||
|
||||
欢迎通过Issue和Pull Request参与改进。
|
||||
Contributions via Issue and Pull Request are welcome.
|
||||
|
||||
---
|
||||
|
||||
## 11. 许可证 / License
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 MCP MySQL Query Server
|
||||
|
||||
特此免费授予任何获得本软件副本和相关文档文件("软件")的人不受限制地处理本软件的权利,包括不受限制地使用、复制、修改、合并、发布、分发、再许可和/或出售本软件副本,以及允许本软件的使用者这样做,但须符合以下条件:
|
||||
|
||||
上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
|
||||
|
||||
本软件按"原样"提供,不提供任何形式的明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,产生于、源于或与本软件有关,或与本软件的使用或其他交易有关。
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
46
example.env
Normal file
46
example.env
Normal file
@ -0,0 +1,46 @@
|
||||
# 服务器配置
|
||||
HOST=127.0.0.1 # 服务器监听地址
|
||||
PORT=3000 # 服务器监听端口
|
||||
|
||||
# 数据库连接配置
|
||||
MYSQL_HOST=localhost # MySQL服务器地址
|
||||
MYSQL_PORT=3306 # MySQL服务器端口
|
||||
MYSQL_USER=root # MySQL用户名
|
||||
MYSQL_PASSWORD= # MySQL密码(留空表示无密码)
|
||||
MYSQL_DATABASE=testdb # 要连接的数据库名
|
||||
DB_CONNECTION_TIMEOUT=5 # 连接超时时间(秒)
|
||||
DB_AUTH_PLUGIN=mysql_native_password # 认证插件类型
|
||||
|
||||
# 数据库连接池配置
|
||||
DB_POOL_ENABLED=true # 是否启用连接池 (true/false)
|
||||
DB_POOL_MIN_SIZE=5 # 连接池最小连接数
|
||||
DB_POOL_MAX_SIZE=20 # 连接池最大连接数
|
||||
DB_POOL_RECYCLE=300 # 连接回收时间(秒)
|
||||
DB_POOL_MAX_LIFETIME=0 # 连接最大存活时间(秒,0表示不限制)
|
||||
DB_POOL_ACQUIRE_TIMEOUT=10.0 # 获取连接超时时间(秒)
|
||||
|
||||
# 环境类型
|
||||
# development: 开发环境,较少限制
|
||||
# production: 生产环境,严格安全限制
|
||||
ENV_TYPE=development
|
||||
|
||||
# 安全配置
|
||||
# 允许的风险等级: LOW(查询), MEDIUM(安全修改), HIGH(结构变更), CRITICAL(危险操作)
|
||||
ALLOWED_RISK_LEVELS=LOW,MEDIUM
|
||||
|
||||
# 是否允许查询敏感字段信息(密码,凭证等)
|
||||
ALLOW_SENSITIVE_INFO=false
|
||||
|
||||
# 最大SQL语句长度限制
|
||||
MAX_SQL_LENGTH=5000
|
||||
|
||||
# 阻止的SQL模式,多个模式用逗号分隔
|
||||
# 例如: 'DROP TABLE,DROP DATABASE,DELETE FROM' 将阻止包含这些字符串的查询
|
||||
BLOCKED_PATTERNS=
|
||||
|
||||
# 是否启用查询安全检查
|
||||
ENABLE_QUERY_CHECK=true
|
||||
|
||||
# 日志配置
|
||||
# DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
LOG_LEVEL=DEBUG
|
||||
@ -1,3 +1,4 @@
|
||||
mcp>=0.1.0
|
||||
mysql-connector-python==8.0.33
|
||||
aiomysql==0.2.0
|
||||
python-dotenv>=0.19.0
|
||||
sqlparse>=0.4.2
|
||||
129
src/config.py
Normal file
129
src/config.py
Normal file
@ -0,0 +1,129 @@
|
||||
import os
|
||||
from typing import Set, List
|
||||
from enum import IntEnum, Enum
|
||||
|
||||
# 环境变量
|
||||
class EnvironmentType(Enum):
|
||||
"""环境类型"""
|
||||
DEVELOPMENT = 'development'
|
||||
PRODUCTION = 'production'
|
||||
|
||||
class SQLRiskLevel(IntEnum):
|
||||
"""SQL操作风险等级"""
|
||||
LOW = 1 # 查询操作(SELECT)
|
||||
MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改
|
||||
CRITICAL = 4 # 危险操作(DROP/TRUNCATE等)
|
||||
|
||||
# 服务器配置
|
||||
class ServerConfig:
|
||||
"""服务器配置"""
|
||||
HOST = os.getenv('HOST', '127.0.0.1')
|
||||
PORT = int(os.getenv('PORT', '3000'))
|
||||
|
||||
# 数据库配置
|
||||
class DatabaseConfig:
|
||||
"""数据库连接配置"""
|
||||
HOST = os.getenv('MYSQL_HOST', 'localhost')
|
||||
USER = os.getenv('MYSQL_USER', 'root')
|
||||
PASSWORD = os.getenv('MYSQL_PASSWORD', '')
|
||||
DATABASE = os.getenv('MYSQL_DATABASE', '')
|
||||
PORT = int(os.getenv('MYSQL_PORT', '3306'))
|
||||
CONNECTION_TIMEOUT = int(os.getenv('DB_CONNECTION_TIMEOUT', '5'))
|
||||
AUTH_PLUGIN = os.getenv('DB_AUTH_PLUGIN', 'mysql_native_password')
|
||||
|
||||
@staticmethod
|
||||
def get_config():
|
||||
"""获取数据库配置字典"""
|
||||
return {
|
||||
'host': DatabaseConfig.HOST,
|
||||
'user': DatabaseConfig.USER,
|
||||
'password': DatabaseConfig.PASSWORD,
|
||||
'database': DatabaseConfig.DATABASE,
|
||||
'port': DatabaseConfig.PORT,
|
||||
'connection_timeout': DatabaseConfig.CONNECTION_TIMEOUT,
|
||||
'auth_plugin': DatabaseConfig.AUTH_PLUGIN
|
||||
}
|
||||
|
||||
# 数据库连接池配置
|
||||
class ConnectionPoolConfig:
|
||||
"""数据库连接池配置"""
|
||||
# 连接池最小连接数
|
||||
MIN_SIZE = int(os.getenv('DB_POOL_MIN_SIZE', '5'))
|
||||
# 连接池最大连接数
|
||||
MAX_SIZE = int(os.getenv('DB_POOL_MAX_SIZE', '20'))
|
||||
# 连接池回收时间(秒)
|
||||
POOL_RECYCLE = int(os.getenv('DB_POOL_RECYCLE', '300'))
|
||||
# 连接最大存活时间(秒,0表示不限制)
|
||||
MAX_LIFETIME = int(os.getenv('DB_POOL_MAX_LIFETIME', '0'))
|
||||
# 连接获取超时时间(秒)
|
||||
ACQUIRE_TIMEOUT = float(os.getenv('DB_POOL_ACQUIRE_TIMEOUT', '10.0'))
|
||||
# 是否启用连接池
|
||||
ENABLED = os.getenv('DB_POOL_ENABLED', 'true').lower() in ('true', 'yes', '1')
|
||||
|
||||
@staticmethod
|
||||
def get_config():
|
||||
"""获取连接池配置字典"""
|
||||
return {
|
||||
'minsize': ConnectionPoolConfig.MIN_SIZE,
|
||||
'maxsize': ConnectionPoolConfig.MAX_SIZE,
|
||||
'pool_recycle': ConnectionPoolConfig.POOL_RECYCLE,
|
||||
'max_lifetime': ConnectionPoolConfig.MAX_LIFETIME,
|
||||
'acquire_timeout': ConnectionPoolConfig.ACQUIRE_TIMEOUT,
|
||||
'enabled': ConnectionPoolConfig.ENABLED
|
||||
}
|
||||
|
||||
# 安全配置
|
||||
class SecurityConfig:
|
||||
"""安全相关配置"""
|
||||
# 环境类型
|
||||
ENV_TYPE_STR = os.getenv('ENV_TYPE', 'development').lower()
|
||||
try:
|
||||
ENV_TYPE = EnvironmentType(ENV_TYPE_STR)
|
||||
except ValueError:
|
||||
ENV_TYPE = EnvironmentType.DEVELOPMENT
|
||||
|
||||
# 允许的风险等级
|
||||
ALLOWED_RISK_LEVELS_STR = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
|
||||
ALLOWED_RISK_LEVELS = set()
|
||||
for level_str in ALLOWED_RISK_LEVELS_STR.upper().split(','):
|
||||
level_str = level_str.strip()
|
||||
try:
|
||||
ALLOWED_RISK_LEVELS.add(SQLRiskLevel[level_str])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# 如果是生产环境且没有明确配置风险等级,则只允许LOW风险操作
|
||||
if ENV_TYPE == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
|
||||
ALLOWED_RISK_LEVELS = {SQLRiskLevel.LOW}
|
||||
|
||||
# 最大SQL长度
|
||||
MAX_SQL_LENGTH = int(os.getenv('MAX_SQL_LENGTH', '1000'))
|
||||
|
||||
# 敏感信息查询
|
||||
ALLOW_SENSITIVE_INFO = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() in ('true', 'yes', '1')
|
||||
|
||||
# 阻止的模式
|
||||
BLOCKED_PATTERNS_STR = os.getenv('BLOCKED_PATTERNS', '')
|
||||
BLOCKED_PATTERNS = [p.strip() for p in BLOCKED_PATTERNS_STR.split(',') if p.strip()]
|
||||
|
||||
# 查询检查
|
||||
ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1')
|
||||
|
||||
# SQL操作配置
|
||||
class SQLConfig:
|
||||
"""SQL操作相关配置"""
|
||||
# 基础操作集合
|
||||
DDL_OPERATIONS = {
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
|
||||
}
|
||||
|
||||
DML_OPERATIONS = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||
}
|
||||
|
||||
# 元数据操作集合
|
||||
METADATA_OPERATIONS = {
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
@ -1,70 +1,309 @@
|
||||
import os
|
||||
import logging
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional
|
||||
import aiomysql
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from ..config import DatabaseConfig, SecurityConfig, SQLConfig, ConnectionPoolConfig
|
||||
from ..security.sql_analyzer import SQLOperationType
|
||||
from ..security.query_limiter import QueryLimiter
|
||||
from ..security.interceptor import SQLInterceptor, SecurityException
|
||||
from ..security.sql_parser import SQLParser
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 初始化安全组件
|
||||
sql_analyzer = SQLOperationType()
|
||||
query_limiter = QueryLimiter()
|
||||
sql_interceptor = SQLInterceptor(sql_analyzer)
|
||||
|
||||
# 全局连接池 - 使用线程本地存储
|
||||
_pools = threading.local()
|
||||
|
||||
# 定期回收无效连接池
|
||||
_cleanup_interval = 300 # 秒,可根据需要调整
|
||||
_last_cleanup = 0
|
||||
|
||||
def _cleanup_unused_pools():
|
||||
"""回收无效或已关闭的连接池,释放资源"""
|
||||
global _last_cleanup
|
||||
now = time.time()
|
||||
if now - _last_cleanup < _cleanup_interval:
|
||||
return
|
||||
_last_cleanup = now
|
||||
if hasattr(_pools, 'pools'):
|
||||
to_remove = []
|
||||
for loop_id, pool in list(_pools.pools.items()):
|
||||
# 检查事件循环是否还活着
|
||||
if pool.closed:
|
||||
to_remove.append(loop_id)
|
||||
continue
|
||||
# 尝试获取事件循环对象
|
||||
for loop in asyncio.all_tasks():
|
||||
if id(loop.get_loop()) == loop_id:
|
||||
break
|
||||
else:
|
||||
# 没找到对应事件循环,关闭池
|
||||
pool.close()
|
||||
to_remove.append(loop_id)
|
||||
logger.info(f"检测到无主事件循环,已关闭连接池 (事件循环ID: {loop_id})")
|
||||
for loop_id in to_remove:
|
||||
del _pools.pools[loop_id]
|
||||
|
||||
def get_db_config():
|
||||
"""动态获取数据库配置"""
|
||||
return {
|
||||
'host': os.getenv('MYSQL_HOST', 'localhost'),
|
||||
'user': os.getenv('MYSQL_USER', 'root'),
|
||||
'password': os.getenv('MYSQL_PASSWORD', ''),
|
||||
'database': os.getenv('MYSQL_DATABASE', ''),
|
||||
'port': int(os.getenv('MYSQL_PORT', '3306')),
|
||||
'connection_timeout': 5,
|
||||
'auth_plugin': 'mysql_native_password' # 指定认证插件
|
||||
# 获取基础配置
|
||||
config = DatabaseConfig.get_config()
|
||||
|
||||
# aiomysql使用不同的配置键名,进行映射
|
||||
aiomysql_config = {
|
||||
'host': config['host'],
|
||||
'user': config['user'],
|
||||
'password': config['password'],
|
||||
'db': config['database'], # 'database' -> 'db'
|
||||
'port': config['port'],
|
||||
'connect_timeout': config.get('connection_timeout', 5), # 'connection_timeout' -> 'connect_timeout'
|
||||
# auth_plugin在aiomysql中不直接支持,忽略此参数
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection():
|
||||
return aiomysql_config
|
||||
|
||||
# 自定义异常类,细化错误处理
|
||||
class MySQLConnectionError(Exception):
|
||||
"""数据库连接错误基类"""
|
||||
pass
|
||||
|
||||
class MySQLAuthError(MySQLConnectionError):
|
||||
"""认证错误"""
|
||||
pass
|
||||
|
||||
class MySQLDatabaseNotFoundError(MySQLConnectionError):
|
||||
"""数据库不存在错误"""
|
||||
pass
|
||||
|
||||
class MySQLServerError(MySQLConnectionError):
|
||||
"""服务器连接错误"""
|
||||
pass
|
||||
|
||||
class MySQLAuthPluginError(MySQLConnectionError):
|
||||
"""认证插件错误"""
|
||||
pass
|
||||
|
||||
async def init_db_pool(min_size: Optional[int] = None, max_size: Optional[int] = None, require_database: bool = True):
|
||||
"""
|
||||
创建数据库连接的上下文管理器
|
||||
初始化数据库连接池
|
||||
|
||||
Args:
|
||||
min_size: 连接池最小连接数 (可选,默认从配置读取)
|
||||
max_size: 连接池最大连接数 (可选,默认从配置读取)
|
||||
require_database: 是否要求指定数据库
|
||||
|
||||
Returns:
|
||||
连接池对象
|
||||
|
||||
Raises:
|
||||
MySQLConnectionError: 连接池初始化失败时
|
||||
"""
|
||||
try:
|
||||
# 获取数据库配置
|
||||
db_config = get_db_config()
|
||||
|
||||
# 检查是否需要数据库名
|
||||
if require_database and not db_config.get('db'):
|
||||
raise MySQLDatabaseNotFoundError("数据库名称未设置,请检查环境变量MYSQL_DATABASE")
|
||||
|
||||
# 如果不需要指定数据库,且db为空,则移除db参数
|
||||
if not require_database and not db_config.get('db'):
|
||||
db_config.pop('db', None)
|
||||
|
||||
# 获取当前事件循环
|
||||
current_loop = asyncio.get_event_loop()
|
||||
loop_id = id(current_loop)
|
||||
|
||||
# 获取连接池配置
|
||||
pool_config = ConnectionPoolConfig.get_config()
|
||||
|
||||
# 使用传入的参数或者配置值
|
||||
min_size = min_size if min_size is not None else pool_config['minsize']
|
||||
max_size = max_size if max_size is not None else pool_config['maxsize']
|
||||
pool_recycle = pool_config['pool_recycle']
|
||||
|
||||
# 检查是否启用连接池
|
||||
if not pool_config['enabled']:
|
||||
logger.warning("连接池功能已被禁用,使用直接连接")
|
||||
# 创建单连接的池
|
||||
min_size = 1
|
||||
max_size = 1
|
||||
|
||||
# 创建连接池
|
||||
logger.info(f"初始化连接池: 最小连接数={min_size}, 最大连接数={max_size}, 回收时间={pool_recycle}秒")
|
||||
pool = await aiomysql.create_pool(
|
||||
minsize=min_size,
|
||||
maxsize=max_size,
|
||||
pool_recycle=pool_recycle,
|
||||
echo=False, # 不记录SQL执行日志,由我们自己的日志系统处理
|
||||
loop=current_loop, # 显式指定事件循环
|
||||
**db_config
|
||||
)
|
||||
|
||||
# 将池存储在线程本地存储中,键是事件循环ID
|
||||
if not hasattr(_pools, 'pools'):
|
||||
_pools.pools = {}
|
||||
_pools.pools[loop_id] = pool
|
||||
|
||||
# 注册事件循环关闭时自动清理
|
||||
def _finalizer(p=pool, lid=loop_id):
|
||||
if not p.closed:
|
||||
p.close()
|
||||
logger.info(f"事件循环关闭时自动关闭连接池 (事件循环ID: {lid})")
|
||||
try:
|
||||
weakref.finalize(current_loop, _finalizer)
|
||||
except Exception as e:
|
||||
logger.warning(f"注册事件循环关闭回调失败: {e}")
|
||||
|
||||
logger.info(f"MySQL连接池初始化成功,最小连接数: {min_size},最大连接数: {max_size},事件循环ID: {loop_id}")
|
||||
return pool
|
||||
except aiomysql.Error as err:
|
||||
error_msg = str(err)
|
||||
logger.error(f"数据库连接池初始化失败: {error_msg}")
|
||||
|
||||
# 细化错误类型
|
||||
if "Access denied" in error_msg:
|
||||
raise MySQLAuthError("访问被拒绝,请检查用户名和密码")
|
||||
elif "Unknown database" in error_msg:
|
||||
raise MySQLDatabaseNotFoundError(f"数据库'{db_config.get('db', '')}'不存在")
|
||||
elif "Can't connect" in error_msg or "Connection refused" in error_msg:
|
||||
raise MySQLServerError("无法连接到MySQL服务器,请检查服务是否启动")
|
||||
elif "Authentication plugin" in error_msg:
|
||||
raise MySQLAuthPluginError(f"认证插件问题: {error_msg},请尝试修改用户认证方式为mysql_native_password")
|
||||
else:
|
||||
raise MySQLConnectionError(f"数据库连接失败: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"连接池初始化发生未预期错误: {str(e)}")
|
||||
raise MySQLConnectionError(f"连接池初始化失败: {str(e)}")
|
||||
|
||||
def get_pool_for_current_loop():
|
||||
"""获取当前事件循环对应的连接池"""
|
||||
_cleanup_unused_pools() # 每次获取时尝试回收
|
||||
try:
|
||||
# 获取当前事件循环ID
|
||||
current_loop = asyncio.get_event_loop()
|
||||
loop_id = id(current_loop)
|
||||
|
||||
# 检查是否有此循环的连接池
|
||||
if hasattr(_pools, 'pools') and loop_id in _pools.pools:
|
||||
pool = _pools.pools[loop_id]
|
||||
# 检查连接池是否已关闭
|
||||
if pool.closed:
|
||||
logger.debug(f"连接池已关闭,将重新创建 (事件循环ID: {loop_id})")
|
||||
return None
|
||||
return pool
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取当前事件循环的连接池失败: {str(e)}")
|
||||
return None
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_connection(require_database: bool = True):
|
||||
"""
|
||||
从连接池获取数据库连接的异步上下文管理器
|
||||
|
||||
Args:
|
||||
require_database: 是否要求必须指定数据库。设置为False时可以执行如SHOW DATABASES等不需要
|
||||
指定具体数据库的操作。
|
||||
|
||||
Yields:
|
||||
mysql.connector.connection.MySQLConnection: 数据库连接对象
|
||||
aiomysql.Connection: 数据库连接对象
|
||||
"""
|
||||
connection = None
|
||||
try:
|
||||
db_config = get_db_config()
|
||||
if not db_config['database']:
|
||||
raise ValueError("数据库名称未设置,请检查环境变量MYSQL_DATABASE")
|
||||
# 获取当前事件循环的连接池
|
||||
pool = get_pool_for_current_loop()
|
||||
|
||||
connection = mysql.connector.connect(**db_config)
|
||||
yield connection
|
||||
except mysql.connector.Error as err:
|
||||
# 如果没有连接池,则初始化一个
|
||||
if pool is None:
|
||||
pool = await init_db_pool(require_database=require_database)
|
||||
|
||||
try:
|
||||
# 从连接池获取连接
|
||||
async with pool.acquire() as connection:
|
||||
yield connection
|
||||
except aiomysql.Error as err:
|
||||
error_msg = str(err)
|
||||
logger.error(f"数据库连接失败: {error_msg}")
|
||||
logger.error(f"获取数据库连接失败: {error_msg}")
|
||||
|
||||
if "Access denied" in error_msg:
|
||||
raise ValueError("访问被拒绝,请检查用户名和密码")
|
||||
raise MySQLAuthError("访问被拒绝,请检查用户名和密码")
|
||||
elif "Unknown database" in error_msg:
|
||||
db_config = get_db_config()
|
||||
raise ValueError(f"数据库'{db_config['database']}'不存在")
|
||||
raise MySQLDatabaseNotFoundError(f"数据库'{db_config.get('db', '')}'不存在")
|
||||
elif "Can't connect" in error_msg or "Connection refused" in error_msg:
|
||||
raise ConnectionError("无法连接到MySQL服务器,请检查服务是否启动")
|
||||
raise MySQLServerError("无法连接到MySQL服务器,请检查服务是否启动")
|
||||
elif "Authentication plugin" in error_msg:
|
||||
raise ValueError(f"认证插件问题: {error_msg},请尝试修改用户认证方式为mysql_native_password")
|
||||
raise MySQLAuthPluginError(f"认证插件问题: {error_msg},请尝试修改用户认证方式为mysql_native_password")
|
||||
else:
|
||||
raise ConnectionError(f"数据库连接失败: {error_msg}")
|
||||
finally:
|
||||
if connection and connection.is_connected():
|
||||
connection.close()
|
||||
logger.debug("数据库连接已关闭")
|
||||
raise MySQLConnectionError(f"数据库连接失败: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库连接时发生未预期错误: {str(e)}")
|
||||
raise MySQLConnectionError(f"获取数据库连接失败: {str(e)}")
|
||||
|
||||
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
async def close_all_pools():
|
||||
"""关闭所有连接池"""
|
||||
if hasattr(_pools, 'pools'):
|
||||
for loop_id, pool in list(_pools.pools.items()):
|
||||
if not pool.closed:
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
logger.info(f"连接池已关闭 (事件循环ID: {loop_id})")
|
||||
_pools.pools = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(connection):
|
||||
"""
|
||||
事务上下文管理器
|
||||
|
||||
用法示例:
|
||||
async with get_db_connection() as conn:
|
||||
async with transaction(conn):
|
||||
await execute_query(conn, "INSERT INTO...")
|
||||
await execute_query(conn, "UPDATE...")
|
||||
|
||||
Args:
|
||||
connection: 数据库连接
|
||||
|
||||
Yields:
|
||||
connection: 事务中的数据库连接
|
||||
"""
|
||||
try:
|
||||
# 开始事务
|
||||
await connection.begin()
|
||||
logger.debug("事务已开始")
|
||||
yield connection
|
||||
# 提交事务
|
||||
await connection.commit()
|
||||
logger.debug("事务已提交")
|
||||
except Exception as e:
|
||||
# 回滚事务
|
||||
await connection.rollback()
|
||||
logger.error(f"事务执行失败,已回滚: {str(e)}")
|
||||
raise
|
||||
|
||||
def normalize_result(result_rows):
|
||||
"""
|
||||
将 DictRow 对象转换为普通字典
|
||||
|
||||
Args:
|
||||
result_rows: 查询结果行列表
|
||||
|
||||
Returns:
|
||||
包含普通字典的列表
|
||||
"""
|
||||
if not result_rows:
|
||||
return []
|
||||
|
||||
return [dict(row) for row in result_rows]
|
||||
|
||||
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None,
|
||||
batch_size: int = 1000, stream_results: bool = False) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在给定的数据库连接上执行查询
|
||||
|
||||
@ -72,6 +311,8 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
||||
connection: 数据库连接
|
||||
query: SQL查询语句
|
||||
params: 查询参数 (可选)
|
||||
batch_size: 批处理大小,控制每次从游标获取的记录数量 (仅当stream_results=True时有效)
|
||||
stream_results: 是否使用流式处理获取大型结果集
|
||||
|
||||
Returns:
|
||||
查询结果列表,如果是修改操作则返回影响的行数
|
||||
@ -81,78 +322,211 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
||||
ValueError: 当查询执行失败时
|
||||
"""
|
||||
cursor = None
|
||||
operation = None # 初始化操作类型变量
|
||||
parsed_sql = None # 初始化SQL解析结果
|
||||
start_time = time.time() # 记录查询开始时间
|
||||
|
||||
try:
|
||||
# 安全检查
|
||||
if not await sql_interceptor.check_operation(query):
|
||||
raise SecurityException("操作被安全机制拒绝")
|
||||
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 创建异步游标,支持字典结果
|
||||
cursor = await connection.cursor(aiomysql.DictCursor)
|
||||
|
||||
# 执行查询
|
||||
# 执行查询 - 异步执行
|
||||
if params:
|
||||
cursor.execute(query, params)
|
||||
# 检查参数类型并转换为适合aiomysql的格式
|
||||
if isinstance(params, dict):
|
||||
# 构建使用%(key)s格式的查询
|
||||
await cursor.execute(query, params)
|
||||
else:
|
||||
await cursor.execute(query, params)
|
||||
else:
|
||||
cursor.execute(query)
|
||||
await cursor.execute(query)
|
||||
|
||||
# 获取操作类型
|
||||
operation = query.strip().split()[0].upper()
|
||||
# 解析SQL语句获取操作类型
|
||||
parsed_sql = SQLParser.parse_query(query)
|
||||
operation = parsed_sql['operation_type']
|
||||
|
||||
# 对于修改操作,提交事务并返回影响的行数
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
if parsed_sql['category'] == 'DML' and operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
affected_rows = cursor.rowcount
|
||||
# 提交事务,确保更改被保存
|
||||
connection.commit()
|
||||
await connection.commit()
|
||||
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||
|
||||
# 记录查询执行时间
|
||||
execution_time = time.time() - start_time
|
||||
_log_query_performance(query, execution_time, operation)
|
||||
|
||||
return [{'affected_rows': affected_rows}]
|
||||
|
||||
# 处理元数据查询操作
|
||||
if operation in sql_analyzer.metadata_operations:
|
||||
# 获取结果集
|
||||
results = cursor.fetchall()
|
||||
if parsed_sql['category'] == 'METADATA':
|
||||
# 元数据查询通常结果较小,直接获取所有结果
|
||||
results = await cursor.fetchall()
|
||||
|
||||
# 没有结果时返回空列表但添加元信息
|
||||
if not results:
|
||||
logger.debug(f"元数据查询 {operation} 没有返回结果")
|
||||
# 记录查询执行时间
|
||||
execution_time = time.time() - start_time
|
||||
_log_query_performance(query, execution_time, operation)
|
||||
return [{'metadata_operation': operation, 'result_count': 0}]
|
||||
|
||||
# 优化结果格式 - 为元数据结果添加额外信息
|
||||
metadata_results = []
|
||||
for row in results:
|
||||
# 对某些特定元数据查询进行特殊处理
|
||||
if operation == 'SHOW' and 'Table' in row:
|
||||
# SHOW TABLES 结果增强
|
||||
row['table_name'] = row['Table']
|
||||
elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row:
|
||||
# DESC/DESCRIBE 表结构结果增强
|
||||
row['column_name'] = row['Field']
|
||||
row['data_type'] = row['Type']
|
||||
# 将行结果转为普通字典,而不是DictCursor的特殊对象
|
||||
row_dict = dict(row)
|
||||
|
||||
metadata_results.append(row)
|
||||
# 对某些特定元数据查询进行特殊处理
|
||||
if operation == 'SHOW' and 'Table' in row_dict:
|
||||
# SHOW TABLES 结果增强
|
||||
row_dict['table_name'] = row_dict['Table']
|
||||
elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row_dict:
|
||||
# DESC/DESCRIBE 表结构结果增强
|
||||
row_dict['column_name'] = row_dict['Field']
|
||||
row_dict['data_type'] = row_dict['Type']
|
||||
|
||||
metadata_results.append(row_dict)
|
||||
|
||||
logger.debug(f"元数据查询 {operation} 返回 {len(metadata_results)} 条结果")
|
||||
|
||||
# 记录查询执行时间
|
||||
execution_time = time.time() - start_time
|
||||
_log_query_performance(query, execution_time, operation)
|
||||
|
||||
return metadata_results
|
||||
|
||||
# 对于查询操作,返回结果集
|
||||
results = cursor.fetchall()
|
||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||
return results
|
||||
# 对于普通查询操作,根据stream_results参数决定结果获取方式
|
||||
if stream_results:
|
||||
# 流式处理大型结果集 - 分批获取
|
||||
all_results = []
|
||||
total_fetched = 0
|
||||
|
||||
# 分批次获取结果
|
||||
while True:
|
||||
batch = await cursor.fetchmany(batch_size)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
# 使用工具函数将DictRow对象转换为普通字典
|
||||
dict_batch = normalize_result(batch)
|
||||
all_results.extend(dict_batch)
|
||||
|
||||
total_fetched += len(batch)
|
||||
logger.debug(f"已获取 {total_fetched} 条记录")
|
||||
|
||||
# 检查是否还有剩余结果
|
||||
if len(batch) < batch_size:
|
||||
break
|
||||
|
||||
logger.debug(f"流式查询总共返回 {len(all_results)} 条结果")
|
||||
|
||||
# 记录查询执行时间
|
||||
execution_time = time.time() - start_time
|
||||
_log_query_performance(query, execution_time, operation)
|
||||
|
||||
return all_results
|
||||
else:
|
||||
# 传统方式 - 一次性获取所有结果
|
||||
results = await cursor.fetchall()
|
||||
|
||||
# 使用工具函数将DictRow对象转换为普通字典
|
||||
dict_results = normalize_result(results)
|
||||
|
||||
logger.debug(f"查询返回 {len(dict_results)} 条结果")
|
||||
|
||||
# 记录查询执行时间
|
||||
execution_time = time.time() - start_time
|
||||
_log_query_performance(query, execution_time, operation)
|
||||
|
||||
return dict_results
|
||||
|
||||
except SecurityException as security_err:
|
||||
logger.error(f"安全检查失败: {str(security_err)}")
|
||||
raise
|
||||
except mysql.connector.Error as query_err:
|
||||
except aiomysql.Error as query_err:
|
||||
# 如果发生错误,进行回滚
|
||||
if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}: # 确保operation已定义
|
||||
if parsed_sql and parsed_sql['operation_type'] in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
try:
|
||||
connection.rollback()
|
||||
await connection.rollback()
|
||||
logger.debug("事务已回滚")
|
||||
except:
|
||||
pass
|
||||
except Exception as rollback_err:
|
||||
logger.error(f"回滚事务失败: {str(rollback_err)}")
|
||||
logger.error(f"查询执行失败: {str(query_err)}")
|
||||
raise ValueError(f"查询执行失败: {str(query_err)}")
|
||||
finally:
|
||||
# 确保游标正确关闭
|
||||
if cursor:
|
||||
cursor.close()
|
||||
await cursor.close()
|
||||
logger.debug("数据库游标已关闭")
|
||||
|
||||
def _log_query_performance(query: str, execution_time: float, operation_type: str = ""):
|
||||
"""
|
||||
记录查询性能日志
|
||||
|
||||
Args:
|
||||
query: SQL查询语句
|
||||
execution_time: 执行时间(秒)
|
||||
operation_type: 操作类型
|
||||
"""
|
||||
# 截断长查询以避免日志过大
|
||||
truncated_query = query[:150] + '...' if len(query) > 150 else query
|
||||
|
||||
# 根据执行时间确定日志级别
|
||||
if execution_time >= 1.0: # 超过1秒的查询记录为警告
|
||||
logger.warning(f"慢查询 [{operation_type}]: {truncated_query} 执行时间: {execution_time:.4f}秒")
|
||||
elif execution_time >= 0.5: # 超过0.5秒的查询记录为提醒
|
||||
logger.info(f"较慢查询 [{operation_type}]: {truncated_query} 执行时间: {execution_time:.4f}秒")
|
||||
else:
|
||||
logger.debug(f"查询 [{operation_type}] 执行时间: {execution_time:.4f}秒")
|
||||
|
||||
async def execute_transaction_queries(connection, queries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在单个事务中执行多个查询
|
||||
|
||||
Args:
|
||||
connection: 数据库连接
|
||||
queries: 查询列表,每个查询是一个包含 'query' 和可选 'params' 的字典
|
||||
|
||||
Returns:
|
||||
所有查询的结果列表
|
||||
|
||||
Raises:
|
||||
Exception: 当任何查询执行失败时,整个事务将回滚
|
||||
"""
|
||||
results = []
|
||||
|
||||
async with transaction(connection):
|
||||
for query_item in queries:
|
||||
query = query_item['query']
|
||||
params = query_item.get('params')
|
||||
|
||||
# 执行单个查询
|
||||
result = await execute_query(connection, query, params)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def get_current_database() -> str:
|
||||
"""
|
||||
获取当前连接的数据库名称
|
||||
|
||||
Returns:
|
||||
当前数据库名称,如果未设置则返回空字符串
|
||||
"""
|
||||
async with get_db_connection(require_database=False) as connection:
|
||||
try:
|
||||
cursor = await connection.cursor(aiomysql.DictCursor)
|
||||
await cursor.execute("SELECT DATABASE() as db")
|
||||
result = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if result and 'db' in result:
|
||||
return result['db'] or ""
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"获取当前数据库名称失败: {str(e)}")
|
||||
return ""
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
from ..config import SecurityConfig, SQLConfig
|
||||
from .sql_analyzer import SQLOperationType, SQLRiskLevel
|
||||
from .sql_parser import SQLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,8 +16,8 @@ class SQLInterceptor:
|
||||
|
||||
def __init__(self, analyzer: SQLOperationType):
|
||||
self.analyzer = analyzer
|
||||
# 设置最大SQL长度限制(默认1000个字符)
|
||||
self.max_sql_length = 1000
|
||||
# 设置最大SQL长度限制
|
||||
self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH
|
||||
|
||||
async def check_operation(self, sql_query: str) -> bool:
|
||||
"""
|
||||
@ -40,19 +41,16 @@ class SQLInterceptor:
|
||||
if len(sql_query) > self.max_sql_length:
|
||||
raise SecurityException(f"SQL语句长度({len(sql_query)})超出限制({self.max_sql_length})")
|
||||
|
||||
# 使用SQLParser解析SQL
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
|
||||
# 检查SQL是否有效
|
||||
sql_parts = sql_query.strip().split()
|
||||
if not sql_parts:
|
||||
if not parsed_sql['is_valid']:
|
||||
raise SecurityException("SQL语句格式无效")
|
||||
|
||||
operation = sql_parts[0].upper()
|
||||
operation = parsed_sql['operation_type']
|
||||
# 更新支持的操作类型列表,包括元数据操作
|
||||
supported_operations = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE',
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE',
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
supported_operations = SQLConfig.DDL_OPERATIONS | SQLConfig.DML_OPERATIONS | SQLConfig.METADATA_OPERATIONS
|
||||
|
||||
if operation not in supported_operations:
|
||||
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||
@ -74,13 +72,11 @@ class SQLInterceptor:
|
||||
)
|
||||
|
||||
# 确定操作类型(DDL, DML 或 元数据)
|
||||
operation_category = "元数据操作" if operation in self.analyzer.metadata_operations else (
|
||||
"DDL操作" if operation in self.analyzer.ddl_operations else "DML操作"
|
||||
)
|
||||
operation_category = parsed_sql['category']
|
||||
|
||||
# 记录详细日志
|
||||
logger.info(
|
||||
f"SQL{operation_category}检查通过 - "
|
||||
f"SQL{operation_category}操作检查通过 - "
|
||||
f"操作: {risk_analysis['operation']}, "
|
||||
f"风险等级: {risk_analysis['risk_level'].name}, "
|
||||
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from ..config import SecurityConfig, SQLConfig
|
||||
from .sql_parser import SQLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class QueryLimiter:
|
||||
"""查询安全检查器"""
|
||||
|
||||
def __init__(self):
|
||||
# 解析启用状态(默认启用)
|
||||
enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true')
|
||||
self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'}
|
||||
# 从配置中获取启用状态
|
||||
self.enable_check = SecurityConfig.ENABLE_QUERY_CHECK
|
||||
|
||||
def check_query(self, sql_query: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
@ -25,44 +26,14 @@ class QueryLimiter:
|
||||
if not self.enable_check:
|
||||
return True, ""
|
||||
|
||||
sql_query = sql_query.strip().upper()
|
||||
operation_type = self._get_operation_type(sql_query)
|
||||
# 使用SQLParser解析SQL
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
operation_type = parsed_sql['operation_type']
|
||||
|
||||
# 检查是否为无 WHERE 子句的更新/删除操作
|
||||
if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query:
|
||||
if operation_type in {'UPDATE', 'DELETE'} and not parsed_sql['has_where']:
|
||||
error_msg = f"{operation_type}操作必须包含WHERE子句"
|
||||
logger.warning(f"查询被限制: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
def _get_operation_type(self, sql_query: str) -> str:
|
||||
"""获取SQL操作类型"""
|
||||
if not sql_query:
|
||||
return ""
|
||||
words = sql_query.split()
|
||||
if not words:
|
||||
return ""
|
||||
return words[0].upper()
|
||||
|
||||
def _parse_int_env(self, env_name: str, default: int) -> int:
|
||||
"""解析整数类型的环境变量"""
|
||||
try:
|
||||
return int(os.getenv(env_name, str(default)))
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
def update_limits(self, new_limits: dict):
|
||||
"""
|
||||
更新限制阈值
|
||||
|
||||
Args:
|
||||
new_limits: 新的限制值字典
|
||||
"""
|
||||
for operation, limit in new_limits.items():
|
||||
if operation in self.max_limits:
|
||||
try:
|
||||
self.max_limits[operation] = int(limit)
|
||||
logger.info(f"更新{operation}操作的限制为: {limit}")
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的限制值: {operation}={limit}")
|
||||
@ -1,81 +1,31 @@
|
||||
import re
|
||||
import os
|
||||
from enum import IntEnum, Enum
|
||||
import logging
|
||||
from typing import Set, List
|
||||
from typing import List, Set
|
||||
|
||||
from ..config import SQLRiskLevel, EnvironmentType, SecurityConfig, SQLConfig
|
||||
from .sql_parser import SQLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLRiskLevel(IntEnum):
|
||||
"""SQL操作风险等级"""
|
||||
LOW = 1 # 查询操作(SELECT)
|
||||
MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改
|
||||
CRITICAL = 4 # 危险操作(DROP/TRUNCATE等)
|
||||
|
||||
class EnvironmentType(Enum):
|
||||
"""环境类型"""
|
||||
DEVELOPMENT = 'development'
|
||||
PRODUCTION = 'production'
|
||||
|
||||
class SQLOperationType:
|
||||
"""SQL操作类型分析器"""
|
||||
|
||||
def __init__(self):
|
||||
# 环境类型处理
|
||||
env_type_str = os.getenv('ENV_TYPE', 'development').lower()
|
||||
try:
|
||||
self.env_type = EnvironmentType(env_type_str)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development")
|
||||
self.env_type = EnvironmentType.DEVELOPMENT
|
||||
# 环境类型从配置读取
|
||||
self.env_type = SecurityConfig.ENV_TYPE
|
||||
|
||||
# 基础操作集合
|
||||
self.ddl_operations = {
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
|
||||
}
|
||||
self.dml_operations = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||
}
|
||||
# 操作类型集合从配置读取
|
||||
self.ddl_operations = SQLConfig.DDL_OPERATIONS
|
||||
self.dml_operations = SQLConfig.DML_OPERATIONS
|
||||
self.metadata_operations = SQLConfig.METADATA_OPERATIONS
|
||||
|
||||
# 添加元数据操作集合
|
||||
self.metadata_operations = {
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
|
||||
# 风险等级配置
|
||||
self.allowed_risk_levels = self._parse_risk_levels()
|
||||
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||
|
||||
# 生产环境特殊处理:如果没有明确配置风险等级,则只允许LOW风险操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
|
||||
self.allowed_risk_levels = {SQLRiskLevel.LOW}
|
||||
# 风险等级配置从配置读取
|
||||
self.allowed_risk_levels = SecurityConfig.ALLOWED_RISK_LEVELS
|
||||
self.blocked_patterns = SecurityConfig.BLOCKED_PATTERNS
|
||||
|
||||
logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}")
|
||||
logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}")
|
||||
|
||||
def _parse_risk_levels(self) -> Set[SQLRiskLevel]:
|
||||
"""解析允许的风险等级"""
|
||||
allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
|
||||
allowed_levels = set()
|
||||
|
||||
logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'")
|
||||
|
||||
for level_str in allowed_levels_str.upper().split(','):
|
||||
level_str = level_str.strip()
|
||||
try:
|
||||
allowed_levels.add(SQLRiskLevel[level_str])
|
||||
except KeyError:
|
||||
logger.warning(f"未知的风险等级配置: {level_str}")
|
||||
|
||||
return allowed_levels
|
||||
|
||||
def _parse_blocked_patterns(self, env_var: str) -> List[str]:
|
||||
"""解析禁止的操作模式"""
|
||||
patterns = os.getenv(env_var, '').split(',')
|
||||
return [p.strip() for p in patterns if p.strip()]
|
||||
|
||||
def analyze_risk(self, sql_query: str) -> dict:
|
||||
"""
|
||||
分析SQL查询的风险级别和影响范围
|
||||
@ -105,54 +55,76 @@ class SQLOperationType:
|
||||
'is_allowed': False
|
||||
}
|
||||
|
||||
operation = sql_query.split()[0].upper()
|
||||
# 使用SQLParser解析SQL
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
operation = parsed_sql['operation_type']
|
||||
|
||||
# 基本风险分析
|
||||
risk_analysis = {
|
||||
'operation': operation,
|
||||
'operation_type': 'DDL' if operation in self.ddl_operations else 'DML',
|
||||
'operation_type': parsed_sql['category'],
|
||||
'is_dangerous': self._check_dangerous_patterns(sql_query),
|
||||
'affected_tables': self._get_affected_tables(sql_query),
|
||||
'estimated_impact': self._estimate_impact(sql_query)
|
||||
'affected_tables': parsed_sql['tables'],
|
||||
'estimated_impact': self._estimate_impact(sql_query, parsed_sql)
|
||||
}
|
||||
|
||||
# 计算风险等级
|
||||
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'])
|
||||
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'], parsed_sql['has_where'])
|
||||
risk_analysis['risk_level'] = risk_level
|
||||
risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels
|
||||
|
||||
return risk_analysis
|
||||
|
||||
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel:
|
||||
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool, has_where: bool) -> SQLRiskLevel:
|
||||
"""
|
||||
计算操作风险等级
|
||||
|
||||
规则:
|
||||
1. 危险操作(匹配危险模式)=> CRITICAL
|
||||
2. DDL操作:
|
||||
2. 生产环境非SELECT操作 => CRITICAL
|
||||
3. DDL操作:
|
||||
- CREATE/ALTER => HIGH
|
||||
- DROP/TRUNCATE => CRITICAL
|
||||
3. DML操作:
|
||||
4. DML操作:
|
||||
- SELECT => LOW
|
||||
- INSERT => MEDIUM
|
||||
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||
- UPDATE(无WHERE)=> HIGH
|
||||
- DELETE(无WHERE)=> CRITICAL
|
||||
4. 元数据操作:
|
||||
5. 元数据操作:
|
||||
- SHOW/DESC/DESCRIBE等 => LOW
|
||||
6. 多语句SQL通常被认为是高风险的
|
||||
"""
|
||||
# 解析SQL获取额外信息
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
|
||||
# 危险操作
|
||||
if is_dangerous:
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 生产环境特别规则
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中只允许SELECT和元数据操作
|
||||
if operation != 'SELECT' and parsed_sql['category'] != 'METADATA':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 生产环境中的多语句SQL视为高风险
|
||||
if parsed_sql.get('multi_statement', False):
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
# 多语句SQL在任何环境中都是更高风险的
|
||||
if parsed_sql.get('multi_statement', False):
|
||||
# 至少中等风险,如果包含DDL则为高风险或严重风险
|
||||
if parsed_sql['category'] == 'DDL':
|
||||
return SQLRiskLevel.HIGH
|
||||
elif parsed_sql['category'] == 'DML' and operation not in {'SELECT'}:
|
||||
return SQLRiskLevel.HIGH
|
||||
return SQLRiskLevel.MEDIUM
|
||||
|
||||
# 元数据操作
|
||||
if operation in self.metadata_operations:
|
||||
return SQLRiskLevel.LOW # 元数据查询视为低风险操作
|
||||
|
||||
# 生产环境中非SELECT操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# DDL操作
|
||||
if operation in self.ddl_operations:
|
||||
if operation in {'DROP', 'TRUNCATE'}:
|
||||
@ -161,61 +133,58 @@ class SQLOperationType:
|
||||
|
||||
# DML操作
|
||||
if operation == 'SELECT':
|
||||
# 对于不带LIMIT的大型SELECT, 风险可能提高
|
||||
if not parsed_sql['has_limit'] and self.env_type == EnvironmentType.PRODUCTION:
|
||||
return SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.LOW
|
||||
elif operation == 'INSERT':
|
||||
return SQLRiskLevel.MEDIUM
|
||||
elif operation == 'UPDATE':
|
||||
return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.HIGH if not has_where else SQLRiskLevel.MEDIUM
|
||||
elif operation == 'DELETE':
|
||||
# 无WHERE条件的DELETE操作视为CRITICAL风险
|
||||
return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.CRITICAL if not has_where else SQLRiskLevel.MEDIUM
|
||||
|
||||
# 默认情况
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
def _check_dangerous_patterns(self, sql_query: str) -> bool:
|
||||
"""检查是否匹配危险操作模式"""
|
||||
sql_upper = sql_query.upper()
|
||||
|
||||
# 生产环境额外的安全检查
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中禁止所有非SELECT操作
|
||||
if sql_upper.split()[0] != 'SELECT':
|
||||
return True
|
||||
# 解析SQL以获取更多信息
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
|
||||
# 检查是否为多语句SQL - 大多数情况下使用多语句SQL可能是危险的
|
||||
if parsed_sql.get('multi_statement', False) and self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中的多语句SQL视为危险
|
||||
return True
|
||||
|
||||
# 对敏感关键字的检查
|
||||
for pattern in self.blocked_patterns:
|
||||
if re.search(pattern, sql_upper, re.IGNORECASE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_affected_tables(self, sql_query: str) -> list:
|
||||
"""获取受影响的表名列表"""
|
||||
words = sql_query.upper().split()
|
||||
tables = []
|
||||
|
||||
for i, word in enumerate(words):
|
||||
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||
if i + 1 < len(words):
|
||||
table = words[i + 1].strip('`;')
|
||||
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||
tables.append(table)
|
||||
|
||||
return list(set(tables))
|
||||
|
||||
def _estimate_impact(self, sql_query: str) -> dict:
|
||||
def _estimate_impact(self, sql_query: str, parsed_sql: dict) -> dict:
|
||||
"""
|
||||
估算查询影响范围
|
||||
|
||||
Args:
|
||||
sql_query: 原始SQL查询
|
||||
parsed_sql: 解析后的SQL信息
|
||||
|
||||
Returns:
|
||||
dict: 包含预估影响的字典
|
||||
"""
|
||||
operation = sql_query.split()[0].upper()
|
||||
operation = parsed_sql['operation_type']
|
||||
|
||||
impact = {
|
||||
'operation': operation,
|
||||
'estimated_rows': 0,
|
||||
'needs_where': operation in {'UPDATE', 'DELETE'},
|
||||
'has_where': 'WHERE' in sql_query.upper()
|
||||
'has_where': parsed_sql['has_where']
|
||||
}
|
||||
|
||||
# 根据环境类型调整估算
|
||||
|
||||
358
src/security/sql_parser.py
Normal file
358
src/security/sql_parser.py
Normal file
@ -0,0 +1,358 @@
|
||||
import sqlparse
|
||||
import re
|
||||
import logging
|
||||
from typing import List, Set, Tuple, Optional, Dict
|
||||
|
||||
from ..config import SQLConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLParser:
|
||||
"""
|
||||
SQL解析器 - 使用sqlparse库提供更精确的SQL解析功能
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def parse_query(sql_query: str) -> Dict:
|
||||
"""
|
||||
解析SQL查询,返回解析结果
|
||||
|
||||
Args:
|
||||
sql_query: SQL查询语句
|
||||
|
||||
Returns:
|
||||
Dict: 包含解析结果的字典
|
||||
"""
|
||||
if not sql_query or not sql_query.strip():
|
||||
return {
|
||||
'operation_type': '',
|
||||
'tables': [],
|
||||
'has_where': False,
|
||||
'has_limit': False,
|
||||
'is_valid': False,
|
||||
'normalized_query': '',
|
||||
'category': 'UNKNOWN',
|
||||
'multi_statement': False,
|
||||
'statement_count': 0
|
||||
}
|
||||
|
||||
try:
|
||||
# 标准化和格式化SQL
|
||||
formatted_sql = SQLParser._format_sql(sql_query)
|
||||
# 解析SQL语句 - 可能有多个语句
|
||||
parsed = sqlparse.parse(formatted_sql)
|
||||
|
||||
# 检查是否有多个语句
|
||||
is_multi_statement = len(parsed) > 1
|
||||
statement_count = len(parsed)
|
||||
|
||||
if not parsed:
|
||||
return {
|
||||
'operation_type': '',
|
||||
'tables': [],
|
||||
'has_where': False,
|
||||
'has_limit': False,
|
||||
'is_valid': False,
|
||||
'normalized_query': formatted_sql,
|
||||
'category': 'UNKNOWN',
|
||||
'multi_statement': False,
|
||||
'statement_count': 0
|
||||
}
|
||||
|
||||
# 默认分析第一个语句,但记录多语句信息
|
||||
stmt = parsed[0]
|
||||
|
||||
# 获取操作类型
|
||||
operation_type = SQLParser._get_operation_type(stmt)
|
||||
|
||||
# 确定操作类别
|
||||
category = SQLParser._get_operation_category(operation_type)
|
||||
|
||||
# 提取表名 - 汇总所有语句中的表名
|
||||
tables = set()
|
||||
has_where = False
|
||||
has_limit = False
|
||||
|
||||
for statement in parsed:
|
||||
# 将各语句涉及的表合并
|
||||
tables.update(SQLParser._extract_tables(statement))
|
||||
|
||||
# 检查任一语句是否有WHERE子句
|
||||
if SQLParser._has_where_clause(statement):
|
||||
has_where = True
|
||||
|
||||
# 检查任一语句是否有LIMIT子句
|
||||
if SQLParser._has_limit_clause(statement):
|
||||
has_limit = True
|
||||
|
||||
# 对于多语句,获取最高风险的操作类型
|
||||
if is_multi_statement and len(parsed) > 1:
|
||||
operations = []
|
||||
categories = []
|
||||
for statement in parsed:
|
||||
op = SQLParser._get_operation_type(statement)
|
||||
operations.append(op)
|
||||
categories.append(SQLParser._get_operation_category(op))
|
||||
|
||||
# 风险优先级: DDL > DML > METADATA
|
||||
if 'DDL' in categories:
|
||||
category = 'DDL'
|
||||
# 在DDL操作中找出优先级最高的
|
||||
# DROP/TRUNCATE > ALTER > CREATE
|
||||
if 'DROP' in operations or 'TRUNCATE' in operations:
|
||||
operation_type = 'DROP' if 'DROP' in operations else 'TRUNCATE'
|
||||
elif 'ALTER' in operations:
|
||||
operation_type = 'ALTER'
|
||||
elif 'CREATE' in operations:
|
||||
operation_type = 'CREATE'
|
||||
elif 'DML' in categories:
|
||||
category = 'DML'
|
||||
# 在DML操作中找出优先级最高的
|
||||
# DELETE > UPDATE > INSERT > SELECT
|
||||
if 'DELETE' in operations:
|
||||
operation_type = 'DELETE'
|
||||
elif 'UPDATE' in operations:
|
||||
operation_type = 'UPDATE'
|
||||
elif 'INSERT' in operations:
|
||||
operation_type = 'INSERT'
|
||||
elif 'SELECT' in operations:
|
||||
operation_type = 'SELECT'
|
||||
|
||||
return {
|
||||
'operation_type': operation_type,
|
||||
'tables': list(tables),
|
||||
'has_where': has_where,
|
||||
'has_limit': has_limit,
|
||||
'is_valid': True,
|
||||
'normalized_query': formatted_sql,
|
||||
'category': category,
|
||||
'multi_statement': is_multi_statement,
|
||||
'statement_count': statement_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SQL解析错误: {str(e)}")
|
||||
# 回退到简单的字符串解析
|
||||
result = SQLParser._fallback_parse(sql_query)
|
||||
# 添加多语句检测,简单检测分号
|
||||
result['multi_statement'] = ';' in sql_query.strip()
|
||||
result['statement_count'] = sql_query.count(';') + 1 if sql_query.strip() else 0
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_sql(sql_query: str) -> str:
|
||||
"""标准化SQL查询格式"""
|
||||
# 去除多余空白和注释
|
||||
return sqlparse.format(
|
||||
sql_query,
|
||||
strip_comments=True,
|
||||
reindent=True,
|
||||
keyword_case='upper'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_operation_type(stmt: sqlparse.sql.Statement) -> str:
|
||||
"""获取SQL操作类型"""
|
||||
# 获取第一个token
|
||||
if stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.DML:
|
||||
return stmt.tokens[0].value.upper()
|
||||
elif stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.DDL:
|
||||
return stmt.tokens[0].value.upper()
|
||||
elif stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.Keyword:
|
||||
return stmt.tokens[0].value.upper()
|
||||
|
||||
# 如果无法确定,返回空字符串
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_operation_category(operation_type: str) -> str:
|
||||
"""确定操作类别(DDL、DML或元数据)"""
|
||||
if operation_type in SQLConfig.DDL_OPERATIONS:
|
||||
return 'DDL'
|
||||
elif operation_type in SQLConfig.DML_OPERATIONS:
|
||||
return 'DML'
|
||||
elif operation_type in SQLConfig.METADATA_OPERATIONS:
|
||||
return 'METADATA'
|
||||
else:
|
||||
return 'UNKNOWN'
|
||||
|
||||
@staticmethod
|
||||
def _extract_tables(stmt: sqlparse.sql.Statement) -> List[str]:
|
||||
"""从SQL语句中提取所有表名"""
|
||||
tables = []
|
||||
|
||||
# 根据操作类型处理表名提取
|
||||
operation_type = SQLParser._get_operation_type(stmt)
|
||||
|
||||
# 递归函数用于深入处理复杂的SQL结构
|
||||
def extract_from_token_list(token_list):
|
||||
local_tables = []
|
||||
in_from_clause = False
|
||||
in_join_clause = False
|
||||
|
||||
for token in token_list.tokens:
|
||||
# 检测FROM子句
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
|
||||
in_from_clause = True
|
||||
continue
|
||||
|
||||
# 检测JOIN子句
|
||||
if token.ttype is sqlparse.tokens.Keyword and 'JOIN' in token.value.upper():
|
||||
in_join_clause = True
|
||||
continue
|
||||
|
||||
# 在FROM或JOIN子句后提取表名
|
||||
if in_from_clause or in_join_clause:
|
||||
if isinstance(token, sqlparse.sql.Identifier):
|
||||
# 直接引用的表名
|
||||
if token.get_real_name():
|
||||
local_tables.append(token.get_real_name())
|
||||
elif isinstance(token, sqlparse.sql.IdentifierList):
|
||||
# 多个表,如FROM table1, table2
|
||||
for identifier in token.get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
local_tables.append(identifier.get_real_name())
|
||||
elif isinstance(token, sqlparse.sql.Function):
|
||||
# 处理子查询中的函数,可能包含表
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
elif isinstance(token, sqlparse.sql.Parenthesis):
|
||||
# 可能是子查询
|
||||
if token.tokens and isinstance(token.tokens[1], sqlparse.sql.Statement):
|
||||
# 是子查询,递归解析
|
||||
local_tables.extend(SQLParser._extract_tables(token.tokens[1]))
|
||||
else:
|
||||
# 其他括号结构,递归处理
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
|
||||
# 重置标志以避免收集其他部分的标识符
|
||||
if token.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.Punctuation):
|
||||
in_from_clause = False
|
||||
in_join_clause = False
|
||||
|
||||
# 递归处理其他TokenList
|
||||
if isinstance(token, sqlparse.sql.TokenList) and not isinstance(token, sqlparse.sql.Identifier):
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
|
||||
return local_tables
|
||||
|
||||
# 特殊处理DML语句
|
||||
if operation_type == 'UPDATE':
|
||||
# UPDATE语句通常在第一个标识符中包含表名
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.DML and token.value.upper() == 'UPDATE':
|
||||
if i+1 < len(stmt.tokens):
|
||||
if isinstance(stmt.tokens[i+1], sqlparse.sql.Identifier):
|
||||
tables.append(stmt.tokens[i+1].get_real_name())
|
||||
elif isinstance(stmt.tokens[i+1], sqlparse.sql.IdentifierList):
|
||||
# 多表更新
|
||||
for identifier in stmt.tokens[i+1].get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
tables.append(identifier.get_real_name())
|
||||
break
|
||||
elif operation_type == 'INSERT':
|
||||
# INSERT语句
|
||||
into_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'INTO':
|
||||
into_found = True
|
||||
elif into_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
elif into_found and isinstance(token, sqlparse.sql.Function):
|
||||
# 处理INSERT INTO table(...)
|
||||
if token.get_name():
|
||||
tables.append(token.get_name())
|
||||
break
|
||||
elif operation_type == 'DELETE':
|
||||
# DELETE FROM table
|
||||
from_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
|
||||
from_found = True
|
||||
elif from_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
elif from_found and isinstance(token, sqlparse.sql.IdentifierList):
|
||||
for identifier in token.get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
tables.append(identifier.get_real_name())
|
||||
break
|
||||
elif operation_type in {'CREATE', 'ALTER', 'DROP', 'TRUNCATE'}:
|
||||
# DDL语句
|
||||
table_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'TABLE':
|
||||
table_found = True
|
||||
elif table_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
else:
|
||||
# 对于其他语句,通过递归处理提取表名
|
||||
tables.extend(extract_from_token_list(stmt))
|
||||
|
||||
# 移除可能的重复项
|
||||
return list(set([table for table in tables if table]))
|
||||
|
||||
@staticmethod
|
||||
def _has_where_clause(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""检查SQL语句是否包含WHERE子句"""
|
||||
for token in stmt.tokens:
|
||||
if isinstance(token, sqlparse.sql.Where):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _has_limit_clause(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""检查SQL语句是否包含LIMIT子句"""
|
||||
# LIMIT通常作为一个关键字出现
|
||||
for token in stmt.tokens:
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'LIMIT':
|
||||
return True
|
||||
# 处理更复杂的语句结构
|
||||
elif isinstance(token, sqlparse.sql.TokenList):
|
||||
for subtoken in token.tokens:
|
||||
if subtoken.ttype is sqlparse.tokens.Keyword and subtoken.value.upper() == 'LIMIT':
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _fallback_parse(sql_query: str) -> Dict:
|
||||
"""当高级解析失败时,回退到基本字符串解析"""
|
||||
sql_upper = sql_query.strip().upper()
|
||||
parts = sql_upper.split()
|
||||
|
||||
operation_type = parts[0] if parts else ""
|
||||
|
||||
# 确定操作类别
|
||||
category = 'UNKNOWN'
|
||||
if operation_type in SQLConfig.DDL_OPERATIONS:
|
||||
category = 'DDL'
|
||||
elif operation_type in SQLConfig.DML_OPERATIONS:
|
||||
category = 'DML'
|
||||
elif operation_type in SQLConfig.METADATA_OPERATIONS:
|
||||
category = 'METADATA'
|
||||
|
||||
# 基本的表名提取
|
||||
tables = []
|
||||
for i, word in enumerate(parts):
|
||||
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||
if i + 1 < len(parts):
|
||||
table = parts[i + 1].strip('`;')
|
||||
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||
tables.append(table)
|
||||
|
||||
# 简单检查WHERE子句
|
||||
has_where = 'WHERE' in sql_upper
|
||||
|
||||
# 简单检查LIMIT子句
|
||||
has_limit = 'LIMIT' in sql_upper
|
||||
|
||||
return {
|
||||
'operation_type': operation_type,
|
||||
'tables': list(set(tables)),
|
||||
'has_where': has_where,
|
||||
'has_limit': has_limit,
|
||||
'is_valid': bool(operation_type),
|
||||
'normalized_query': sql_query,
|
||||
'category': category
|
||||
}
|
||||
163
src/server.py
163
src/server.py
@ -1,12 +1,19 @@
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
import atexit
|
||||
import signal
|
||||
import importlib
|
||||
import pkgutil
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载
|
||||
load_dotenv()
|
||||
|
||||
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||
from src.config import ServerConfig, SecurityConfig, DatabaseConfig, ConnectionPoolConfig
|
||||
from src.tools.mysql_tool import register_mysql_tool
|
||||
from src.tools.mysql_metadata_tool import register_metadata_tools
|
||||
from src.tools.mysql_info_tool import register_info_tools
|
||||
@ -21,23 +28,23 @@ logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 记录环境变量加载情况
|
||||
logger.debug("已加载环境变量")
|
||||
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||
logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}")
|
||||
logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}")
|
||||
logger.debug(f"当前允许的风险等级: {SecurityConfig.ALLOWED_RISK_LEVELS_STR}")
|
||||
logger.debug(f"当前环境类型: {SecurityConfig.ENV_TYPE.value}")
|
||||
logger.debug(f"是否允许敏感信息查询: {SecurityConfig.ALLOW_SENSITIVE_INFO}")
|
||||
|
||||
# 尝试导入MySQL连接器
|
||||
try:
|
||||
import mysql.connector
|
||||
logger.debug("MySQL连接器导入成功")
|
||||
import aiomysql
|
||||
logger.debug("aiomysql连接器导入成功")
|
||||
mysql_available = True
|
||||
except ImportError as e:
|
||||
logger.critical(f"无法导入MySQL连接器: {str(e)}")
|
||||
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
|
||||
logger.critical(f"无法导入aiomysql连接器: {str(e)}")
|
||||
logger.critical("请确保已安装aiomysql包: pip install aiomysql")
|
||||
mysql_available = False
|
||||
|
||||
# 从环境变量获取服务器配置
|
||||
host = os.getenv('HOST', '127.0.0.1')
|
||||
port = int(os.getenv('PORT', '3000'))
|
||||
# 从配置获取服务器配置
|
||||
host = ServerConfig.HOST
|
||||
port = ServerConfig.PORT
|
||||
logger.debug(f"服务器配置: host={host}, port={port}")
|
||||
|
||||
# 创建MCP服务器实例
|
||||
@ -45,20 +52,120 @@ logger.debug("正在创建MCP服务器实例...")
|
||||
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
|
||||
logger.debug("MCP服务器实例创建完成")
|
||||
|
||||
def auto_register_tools(mcp):
|
||||
"""
|
||||
自动扫描src.tools目录下所有register_开头的函数并注册到mcp
|
||||
"""
|
||||
import src.tools
|
||||
package = src.tools
|
||||
for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package.__name__ + "."):
|
||||
if ispkg:
|
||||
continue
|
||||
module = importlib.import_module(name)
|
||||
for func_name, func in inspect.getmembers(module, inspect.isfunction):
|
||||
if func_name.startswith("register_") and func_name.endswith("tool") or func_name.endswith("tools"):
|
||||
try:
|
||||
func(mcp)
|
||||
logger.info(f"自动注册工具: {name}.{func_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"自动注册工具失败: {name}.{func_name} - {e}")
|
||||
|
||||
# 注册MySQL基础查询工具
|
||||
register_mysql_tool(mcp)
|
||||
auto_register_tools(mcp)
|
||||
logger.debug("已自动注册所有MySQL工具")
|
||||
|
||||
# 注册MySQL元数据查询工具
|
||||
register_metadata_tools(mcp)
|
||||
logger.debug("已注册元数据查询工具")
|
||||
# 启动连接池定时回收任务
|
||||
def _start_pool_cleanup_task():
|
||||
"""启动后台线程定期回收连接池资源"""
|
||||
import time
|
||||
from src.db.mysql_operations import _cleanup_unused_pools
|
||||
def _loop():
|
||||
while True:
|
||||
try:
|
||||
_cleanup_unused_pools()
|
||||
except Exception as e:
|
||||
logger.warning(f"定时回收连接池异常: {e}")
|
||||
time.sleep(300) # 每5分钟回收一次
|
||||
t = threading.Thread(target=_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
# 注册MySQL数据库信息查询工具
|
||||
register_info_tools(mcp)
|
||||
logger.debug("已注册数据库信息查询工具")
|
||||
# 用于保存事件循环和初始化状态
|
||||
_server_data = {
|
||||
'loop': None,
|
||||
'db_initialized': False
|
||||
}
|
||||
|
||||
# 注册MySQL表结构高级查询工具
|
||||
register_schema_tools(mcp)
|
||||
logger.debug("已注册表结构高级查询工具")
|
||||
def cleanup_resources():
|
||||
"""清理资源,关闭连接池"""
|
||||
if _server_data['loop'] and _server_data['db_initialized']:
|
||||
try:
|
||||
# 导入连接池关闭函数
|
||||
from src.db.mysql_operations import close_all_pools
|
||||
|
||||
# 创建关闭任务并运行
|
||||
logger.info("正在关闭所有数据库连接池...")
|
||||
close_task = close_all_pools()
|
||||
|
||||
# 在当前事件循环中运行
|
||||
if _server_data['loop'].is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(close_task, _server_data['loop'])
|
||||
future.result(timeout=5) # 等待最多5秒
|
||||
else:
|
||||
# 如果循环已经停止,创建新的循环运行清理任务
|
||||
temp_loop = asyncio.new_event_loop()
|
||||
temp_loop.run_until_complete(close_task)
|
||||
temp_loop.close()
|
||||
|
||||
logger.info("数据库连接池已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭数据库连接池时出错: {str(e)}")
|
||||
|
||||
# 注册退出处理函数
|
||||
atexit.register(cleanup_resources)
|
||||
|
||||
# 注册信号处理
|
||||
def signal_handler(sig, frame):
|
||||
"""处理终止信号"""
|
||||
logger.info(f"收到信号 {sig},正在清理资源...")
|
||||
cleanup_resources()
|
||||
# 正常退出
|
||||
exit(0)
|
||||
|
||||
# 注册常见的终止信号
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # kill
|
||||
|
||||
async def init_database():
|
||||
"""初始化数据库连接池"""
|
||||
try:
|
||||
from src.db.mysql_operations import init_db_pool, get_db_config
|
||||
|
||||
# 获取数据库配置
|
||||
db_config = get_db_config()
|
||||
|
||||
# 获取连接池配置
|
||||
pool_config = ConnectionPoolConfig.get_config()
|
||||
min_size = pool_config['minsize']
|
||||
max_size = pool_config['maxsize']
|
||||
|
||||
# 记录连接池配置
|
||||
logger.info(f"连接池配置: 最小连接数={min_size}, 最大连接数={max_size}, 回收时间={pool_config['pool_recycle']}秒")
|
||||
logger.info(f"连接池功能状态: {'启用' if pool_config['enabled'] else '禁用'}")
|
||||
|
||||
if not db_config.get('db'):
|
||||
logger.warning("未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
print("警告: 未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
# 初始化连接池但不要求指定数据库
|
||||
await init_db_pool(require_database=False)
|
||||
else:
|
||||
# 正常初始化连接池
|
||||
await init_db_pool()
|
||||
|
||||
logger.info("数据库连接池初始化完成")
|
||||
_server_data['db_initialized'] = True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接池初始化失败: {str(e)}")
|
||||
print(f"警告: 数据库连接池初始化失败: {str(e)}")
|
||||
|
||||
def start_server():
|
||||
"""启动SSE服务器的同步包装器"""
|
||||
@ -68,12 +175,11 @@ def start_server():
|
||||
print(f"服务器监听在 {host}:{port}/sse")
|
||||
|
||||
try:
|
||||
# 检查MySQL配置是否有效
|
||||
from src.db.mysql_operations import get_db_config
|
||||
db_config = get_db_config()
|
||||
if mysql_available and not db_config['database']:
|
||||
logger.warning("未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
print("警告: 未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
# 检查MySQL配置是否有效并初始化连接池
|
||||
if mysql_available:
|
||||
# 使用事件循环执行异步初始化函数
|
||||
_server_data['loop'] = asyncio.get_event_loop()
|
||||
_server_data['loop'].run_until_complete(init_database())
|
||||
|
||||
# 使用run_app函数启动服务器
|
||||
logger.debug("调用mcp.run('sse')启动服务器...")
|
||||
@ -81,6 +187,9 @@ def start_server():
|
||||
except Exception as e:
|
||||
logger.exception(f"服务器运行时发生错误: {str(e)}")
|
||||
print(f"服务器运行时发生错误: {str(e)}")
|
||||
finally:
|
||||
# 确保资源被清理
|
||||
cleanup_resources()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 确保初始化后工具才被注册
|
||||
|
||||
@ -5,12 +5,12 @@ MySQL元数据工具基类
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable
|
||||
from typing import Any, Dict, List, Optional, Union, Callable
|
||||
import functools
|
||||
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
from src.validators import SQLValidators, ValidationError
|
||||
|
||||
T = TypeVar('T')
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
class MySQLToolError(Exception):
|
||||
@ -50,8 +50,10 @@ class MetadataToolBase:
|
||||
Raises:
|
||||
ParameterValidationError: 当参数验证失败时
|
||||
"""
|
||||
if param_value is not None and not validator(param_value):
|
||||
raise ParameterValidationError(f"{param_name} - {error_message}")
|
||||
try:
|
||||
SQLValidators.validate_parameter(param_name, param_value, validator, "参数验证")
|
||||
except ValidationError as e:
|
||||
raise ParameterValidationError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str:
|
||||
@ -91,13 +93,22 @@ class MetadataToolBase:
|
||||
return await func(*args, **kwargs)
|
||||
except ParameterValidationError as e:
|
||||
logger.error(f"参数验证错误: {str(e)}")
|
||||
return json.dumps({"error": f"参数错误: {str(e)}"})
|
||||
return json.dumps({
|
||||
"error": f"参数错误: {str(e)}",
|
||||
"error_type": "ParameterValidationError"
|
||||
})
|
||||
except QueryExecutionError as e:
|
||||
logger.error(f"查询执行错误: {str(e)}")
|
||||
return json.dumps({"error": f"查询执行失败: {str(e)}"})
|
||||
return json.dumps({
|
||||
"error": f"查询执行失败: {str(e)}",
|
||||
"error_type": "QueryExecutionError"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"未预期的错误: {str(e)}")
|
||||
return json.dumps({"error": f"操作失败: {str(e)}"})
|
||||
return json.dumps({
|
||||
"error": f"操作失败: {str(e)}",
|
||||
"error_type": "UnexpectedError"
|
||||
})
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
@ -115,9 +126,9 @@ class MetadataToolBase:
|
||||
查询结果的JSON字符串
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
return MetadataToolBase.format_results(results, operation_type)
|
||||
except Exception as e:
|
||||
logger.error(f"元数据查询执行失败: {str(e)}")
|
||||
raise QueryExecutionError(str(e))
|
||||
raise QueryExecutionError(str(e)) from e # 保留原始异常链
|
||||
@ -13,6 +13,7 @@ from mcp.server.fastmcp import FastMCP
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.security.sql_analyzer import EnvironmentType
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
from src.validators import SQLValidators
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
@ -46,6 +47,10 @@ SENSITIVE_VARIABLE_PREFIXES = [
|
||||
"authentication", "secure", "credential", "token"
|
||||
]
|
||||
|
||||
# 变量名和值字段映射
|
||||
VARIABLE_NAME_FIELDS = ['Variable_name', 'variable_name', 'name', 'Name', 'key', 'Key', 'Setting']
|
||||
VALUE_FIELDS = ['Value', 'value', 'variable_value', 'val', 'setting', 'Setting_Value']
|
||||
|
||||
def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool:
|
||||
"""
|
||||
检查当前环境是否允许执行特定类型的查询
|
||||
@ -92,18 +97,26 @@ def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[s
|
||||
# 复制一份,避免修改原始数据
|
||||
filtered_item = item.copy()
|
||||
|
||||
# 检查常见的变量名字段
|
||||
for field in ['Variable_name', 'variable_name', 'name']:
|
||||
# 确定哪个字段包含变量名
|
||||
name_field = None
|
||||
for field in VARIABLE_NAME_FIELDS:
|
||||
if field in filtered_item:
|
||||
var_name = filtered_item[field].lower()
|
||||
# 检查是否匹配敏感模式
|
||||
is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
|
||||
name_field = field
|
||||
break
|
||||
|
||||
if is_sensitive:
|
||||
# 敏感信息,隐藏具体的值
|
||||
for value_field in ['Value', 'value', 'variable_value']:
|
||||
if value_field in filtered_item:
|
||||
filtered_item[value_field] = '*** HIDDEN ***'
|
||||
# 如果找到变量名字段,检查是否敏感
|
||||
if name_field:
|
||||
var_name = str(filtered_item[name_field]).lower()
|
||||
# 检查是否匹配敏感模式
|
||||
is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
|
||||
|
||||
if is_sensitive:
|
||||
# 找出所有可能的值字段
|
||||
for value_field in VALUE_FIELDS:
|
||||
if value_field in filtered_item:
|
||||
# 敏感信息,隐藏具体的值
|
||||
filtered_item[value_field] = '*** HIDDEN ***'
|
||||
logger.debug(f"已隐藏敏感变量 '{var_name}' 的值")
|
||||
|
||||
filtered_results.append(filtered_item)
|
||||
|
||||
@ -136,21 +149,21 @@ def register_info_tools(mcp: FastMCP):
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
SQLValidators.validate_like_pattern,
|
||||
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"limit", limit,
|
||||
lambda x: isinstance(x, int) and x >= 0,
|
||||
lambda x: SQLValidators.validate_integer(x, min_value=0),
|
||||
"返回结果的最大数量必须是非负整数"
|
||||
)
|
||||
|
||||
# 构建基础查询
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
# 执行查询
|
||||
with get_db_connection() as connection:
|
||||
# 执行查询 - 使用异步上下文管理器,不要求预先指定数据库
|
||||
async with get_db_connection(require_database=False) as connection:
|
||||
# 先获取所有数据库
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
@ -227,7 +240,7 @@ def register_info_tools(mcp: FastMCP):
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
SQLValidators.validate_like_pattern,
|
||||
"变量模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
@ -239,7 +252,7 @@ def register_info_tools(mcp: FastMCP):
|
||||
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
# 生产环境中过滤敏感信息
|
||||
@ -273,7 +286,7 @@ def register_info_tools(mcp: FastMCP):
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
SQLValidators.validate_like_pattern,
|
||||
"状态模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
@ -285,7 +298,7 @@ def register_info_tools(mcp: FastMCP):
|
||||
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
# 生产环境中过滤敏感信息
|
||||
@ -293,40 +306,3 @@ def register_info_tools(mcp: FastMCP):
|
||||
results = filter_sensitive_info(results)
|
||||
|
||||
return MetadataToolBase.format_results(results, operation_type="服务器状态查询")
|
||||
|
||||
# 工具函数: 用于参数验证
|
||||
def validate_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
验证模式字符串是否安全 (防止SQL注入)
|
||||
|
||||
Args:
|
||||
pattern: 要验证的模式字符串
|
||||
|
||||
Returns:
|
||||
如果模式安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当模式包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||
return True
|
||||
|
||||
def validate_engine_name(name: str) -> bool:
|
||||
"""
|
||||
验证存储引擎名称是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的引擎名称
|
||||
|
||||
Returns:
|
||||
如果引擎名称安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当引擎名称包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线")
|
||||
return True
|
||||
@ -11,64 +11,10 @@ from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
from src.validators import SQLValidators
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 工具函数: 用于参数验证
|
||||
def validate_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
验证模式字符串是否安全 (防止SQL注入)
|
||||
|
||||
Args:
|
||||
pattern: 要验证的模式字符串
|
||||
|
||||
Returns:
|
||||
如果模式安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当模式包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||
return True
|
||||
|
||||
def validate_table_name(name: str) -> bool:
|
||||
"""
|
||||
验证表名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的表名
|
||||
|
||||
Returns:
|
||||
如果表名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当表名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_database_name(name: str) -> bool:
|
||||
"""
|
||||
验证数据库名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的数据库名
|
||||
|
||||
Returns:
|
||||
如果数据库名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据库名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def register_metadata_tools(mcp: FastMCP):
|
||||
"""
|
||||
注册MySQL元数据查询工具到MCP服务器
|
||||
@ -98,20 +44,20 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
SQLValidators.validate_like_pattern,
|
||||
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"limit", limit,
|
||||
lambda x: isinstance(x, int) and x >= 0,
|
||||
lambda x: SQLValidators.validate_integer(x, min_value=0),
|
||||
"返回结果的最大数量必须是非负整数"
|
||||
)
|
||||
|
||||
@ -124,8 +70,8 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
|
||||
logger.debug(f"执行查询: {base_query}")
|
||||
|
||||
# 执行查询
|
||||
with get_db_connection() as connection:
|
||||
# 执行查询 - 使用异步上下文管理器
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(connection, base_query)
|
||||
|
||||
# 如果需要排除视图,且使用的是SHOW FULL TABLES
|
||||
@ -133,17 +79,46 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
filtered_results = []
|
||||
|
||||
# 查找表名和表类型字段
|
||||
fields = list(results[0].keys()) if results else []
|
||||
table_field = fields[0] if fields else None
|
||||
table_type_field = fields[1] if len(fields) > 1 else None
|
||||
if results:
|
||||
# 确定表名和表类型字段名
|
||||
field_names = list(results[0].keys())
|
||||
table_type_field = None
|
||||
table_field = None
|
||||
|
||||
if table_field and table_type_field:
|
||||
# 基表类型通常是"BASE TABLE"
|
||||
for item in results:
|
||||
if item[table_type_field] == 'BASE TABLE':
|
||||
filtered_results.append(item)
|
||||
# 查找表类型字段 - 这通常是'Table_type',但也检查其他可能的名称
|
||||
possible_type_fields = ['Table_type', 'table_type', 'type']
|
||||
for field in possible_type_fields:
|
||||
if field in field_names:
|
||||
table_type_field = field
|
||||
break
|
||||
|
||||
# 查找表名字段 - 这可能是结果中的第一个字段
|
||||
for field in field_names:
|
||||
if field != table_type_field: # 表名不会是类型字段
|
||||
if field.lower() in ['table', 'name', 'table_name', 'tables_in_']:
|
||||
table_field = field
|
||||
break
|
||||
|
||||
# 如果没找到明确的表名字段,使用第一个非类型字段
|
||||
if not table_field and len(field_names) > 0:
|
||||
for field in field_names:
|
||||
if field != table_type_field:
|
||||
table_field = field
|
||||
break
|
||||
|
||||
# 只有当我们能确定表名和类型字段时才进行过滤
|
||||
if table_field and table_type_field:
|
||||
logger.debug(f"表名字段: {table_field}, 表类型字段: {table_type_field}")
|
||||
# 只保留基表 (BASE TABLE),排除视图和其他对象
|
||||
for item in results:
|
||||
if item[table_type_field] == 'BASE TABLE':
|
||||
filtered_results.append(item)
|
||||
else:
|
||||
# 如果无法确定字段,保留所有结果并记录警告
|
||||
logger.warning("无法确定表类型字段,无法排除视图")
|
||||
filtered_results = results
|
||||
else:
|
||||
filtered_results = results
|
||||
filtered_results = []
|
||||
else:
|
||||
filtered_results = results
|
||||
|
||||
@ -189,14 +164,14 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_table_name,
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
@ -221,22 +196,21 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_table_name,
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
# DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
|
||||
query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
|
||||
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
@ -254,14 +228,14 @@ def register_metadata_tools(mcp: FastMCP):
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_table_name,
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
|
||||
@ -6,74 +6,21 @@ MySQL表结构高级查询工具
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
from src.validators import SQLValidators
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 参数验证函数
|
||||
def validate_table_name(name: str) -> bool:
|
||||
"""
|
||||
验证表名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的表名
|
||||
|
||||
Returns:
|
||||
如果表名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当表名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_database_name(name: str) -> bool:
|
||||
"""
|
||||
验证数据库名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的数据库名
|
||||
|
||||
Returns:
|
||||
如果数据库名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据库名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_column_name(name: str) -> bool:
|
||||
"""
|
||||
验证列名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的列名
|
||||
|
||||
Returns:
|
||||
如果列名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当列名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
async def execute_schema_query(
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
operation_type: str = "元数据查询"
|
||||
operation_type: str = "元数据查询",
|
||||
stream_results: bool = False,
|
||||
batch_size: int = 1000
|
||||
) -> str:
|
||||
"""
|
||||
执行表结构查询
|
||||
@ -82,12 +29,20 @@ async def execute_schema_query(
|
||||
query: SQL查询语句
|
||||
params: 查询参数 (可选)
|
||||
operation_type: 操作类型描述
|
||||
stream_results: 是否使用流式处理获取大型结果集
|
||||
batch_size: 批处理大小,分批获取结果时的每批记录数量
|
||||
|
||||
Returns:
|
||||
查询结果的JSON字符串
|
||||
"""
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(
|
||||
connection,
|
||||
query,
|
||||
params,
|
||||
batch_size=batch_size,
|
||||
stream_results=stream_results
|
||||
)
|
||||
return MetadataToolBase.format_results(results, operation_type)
|
||||
|
||||
def register_schema_tools(mcp: FastMCP):
|
||||
@ -113,10 +68,18 @@ def register_schema_tools(mcp: FastMCP):
|
||||
表索引信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
validate_table_name(table)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
SQLValidators.validate_table_name,
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
# 构建查询
|
||||
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
|
||||
@ -141,10 +104,18 @@ def register_schema_tools(mcp: FastMCP):
|
||||
"""
|
||||
# 参数验证
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if like_pattern:
|
||||
validate_column_name(like_pattern)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"like_pattern", like_pattern,
|
||||
SQLValidators.validate_like_pattern,
|
||||
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
# 构建查询
|
||||
if database:
|
||||
@ -174,16 +145,24 @@ def register_schema_tools(mcp: FastMCP):
|
||||
表外键约束信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
validate_table_name(table)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
SQLValidators.validate_table_name,
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
SQLValidators.validate_database_name,
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
# 确定数据库名
|
||||
db_name = database
|
||||
if not db_name:
|
||||
# 获取当前数据库
|
||||
with get_db_connection() as connection:
|
||||
# 获取当前数据库 - 使用异步上下文管理器
|
||||
async with get_db_connection() as connection:
|
||||
current_db_results = await execute_query(connection, "SELECT DATABASE() as db")
|
||||
if current_db_results and 'db' in current_db_results[0]:
|
||||
db_name = current_db_results[0]['db']
|
||||
@ -191,7 +170,7 @@ def register_schema_tools(mcp: FastMCP):
|
||||
if not db_name:
|
||||
raise ValueError("无法确定数据库名称,请明确指定database参数")
|
||||
|
||||
# 使用INFORMATION_SCHEMA查询外键
|
||||
# 使用INFORMATION_SCHEMA查询外键 - 修改为使用命名参数
|
||||
query = """
|
||||
SELECT
|
||||
CONSTRAINT_NAME,
|
||||
@ -208,16 +187,19 @@ def register_schema_tools(mcp: FastMCP):
|
||||
ON
|
||||
kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
|
||||
WHERE
|
||||
kcu.TABLE_SCHEMA = %s
|
||||
AND kcu.TABLE_NAME = %s
|
||||
kcu.TABLE_SCHEMA = %(table_schema)s
|
||||
AND kcu.TABLE_NAME = %(table_name)s
|
||||
AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
|
||||
"""
|
||||
params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table}
|
||||
|
||||
logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束")
|
||||
# 使用命名参数,键名与SQL中的占位符对应
|
||||
params = {"table_schema": db_name, "table_name": table}
|
||||
|
||||
logger.debug(f"执行外键查询: {query}")
|
||||
logger.debug(f"参数: {params}")
|
||||
|
||||
# 执行查询
|
||||
return await execute_schema_query(query, params, operation_type="外键约束查询")
|
||||
return await execute_schema_query(query, params, operation_type="表外键查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
@ -236,47 +218,62 @@ def register_schema_tools(mcp: FastMCP):
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"page", page,
|
||||
lambda x: isinstance(x, int) and x > 0,
|
||||
lambda x: SQLValidators.validate_integer(x, min_value=1),
|
||||
"页码必须是正整数"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"page_size", page_size,
|
||||
lambda x: isinstance(x, int) and 1 <= x <= 1000,
|
||||
"每页记录数必须在1-1000之间"
|
||||
lambda x: SQLValidators.validate_integer(x, min_value=1, max_value=1000),
|
||||
"每页记录数必须是正整数且不超过1000"
|
||||
)
|
||||
|
||||
# 检查查询语法
|
||||
if not query.strip().upper().startswith('SELECT'):
|
||||
raise ValueError("只支持SELECT查询的分页")
|
||||
|
||||
# 计算LIMIT和OFFSET
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 在查询末尾添加LIMIT子句
|
||||
paginated_query = query.strip()
|
||||
if 'LIMIT' in paginated_query.upper():
|
||||
raise ValueError("查询已包含LIMIT子句,请移除后重试")
|
||||
# 分离基础查询和LIMIT/OFFSET部分
|
||||
base_query = query.strip()
|
||||
if re.search(r'\bLIMIT\b', base_query, re.IGNORECASE):
|
||||
raise ValueError("查询语句已包含LIMIT子句,不能与分页功能一起使用")
|
||||
|
||||
paginated_query += f" LIMIT {page_size} OFFSET {offset}"
|
||||
# 添加LIMIT和OFFSET
|
||||
paginated_query = f"{base_query} LIMIT {page_size} OFFSET {offset}"
|
||||
|
||||
logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}")
|
||||
logger.debug(f"执行分页查询: {paginated_query}")
|
||||
logger.debug(f"页码: {page}, 每页记录数: {page_size}, 偏移量: {offset}")
|
||||
|
||||
# 获取总记录数(用于计算总页数)
|
||||
count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table"
|
||||
# 执行查询 - 使用异步上下文管理器
|
||||
async with get_db_connection() as connection:
|
||||
# 首先检查并验证查询
|
||||
# 确认查询安全性 - 限制查询类型,只允许SELECT查询
|
||||
if not base_query.strip().upper().startswith('SELECT'):
|
||||
raise ValueError("只支持SELECT查询进行分页")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
# 执行分页查询
|
||||
# 使用普通查询获取当前页结果(不需要流式处理,因为已经有LIMIT限制)
|
||||
results = await execute_query(connection, paginated_query)
|
||||
|
||||
# 获取总记录数
|
||||
count_results = await execute_query(connection, count_query)
|
||||
total_records = count_results[0]['total'] if count_results else 0
|
||||
# 尝试获取总记录数 - 对于大型结果集使用流式处理
|
||||
try:
|
||||
# 由于无法参数化子查询,我们改为构建一个只返回计数的查询
|
||||
# 这仍有SQL注入风险,但我们已经验证查询只能是SELECT
|
||||
count_query = f"SELECT COUNT(*) as total FROM ({base_query}) as subquery"
|
||||
# 计数查询通常只返回一行,不需要流式处理
|
||||
count_results = await execute_query(connection, count_query)
|
||||
total = count_results[0]['total'] if count_results else 0
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (total_records + page_size - 1) // page_size
|
||||
# 根据总记录数计算是否是大型结果集
|
||||
is_large_resultset = total > 1000
|
||||
|
||||
# 构建分页元数据
|
||||
# 提示用户结果集大小
|
||||
if is_large_resultset:
|
||||
logger.info(f"检测到大型结果集,共 {total} 条记录,建议使用较小的 page_size 值")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"无法执行总数查询: {str(e)}")
|
||||
total = None
|
||||
is_large_resultset = False
|
||||
|
||||
# 构造分页元数据
|
||||
pagination_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": "分页查询",
|
||||
@ -284,8 +281,11 @@ def register_schema_tools(mcp: FastMCP):
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_records": total_records,
|
||||
"total_pages": total_pages
|
||||
"total_records": total,
|
||||
"total_pages": (total + page_size - 1) // page_size if total else None,
|
||||
"has_next": (page * page_size < total) if total is not None else len(results) == page_size,
|
||||
"has_previous": page > 1,
|
||||
"is_large_resultset": is_large_resultset if total is not None else None
|
||||
}
|
||||
},
|
||||
"results": results
|
||||
|
||||
@ -3,16 +3,14 @@ import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
import mysql.connector
|
||||
import aiomysql
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 尝试导入MySQL连接器
|
||||
try:
|
||||
mysql.connector
|
||||
mysql_available = True
|
||||
except ImportError:
|
||||
mysql_available = False
|
||||
# MySQL可用性检查变量,默认认为aiomysql已可用
|
||||
mysql_available = True
|
||||
|
||||
def register_mysql_tool(mcp: FastMCP):
|
||||
"""
|
||||
@ -24,6 +22,7 @@ def register_mysql_tool(mcp: FastMCP):
|
||||
logger.debug("注册MySQL查询工具...")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
执行MySQL查询并返回结果
|
||||
@ -37,18 +36,22 @@ def register_mysql_tool(mcp: FastMCP):
|
||||
"""
|
||||
logger.debug(f"执行MySQL查询: {query}, 参数: {params}")
|
||||
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
async with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
|
||||
# 检查是否是修改操作返回的影响行数
|
||||
operation = query.strip().split()[0].upper()
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
|
||||
affected_rows = results[0]['affected_rows']
|
||||
logger.info(f"{operation}操作影响了{affected_rows}行数据")
|
||||
# 检查是否是修改操作返回的影响行数
|
||||
operation = query.strip().split()[0].upper()
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
|
||||
affected_rows = results[0]['affected_rows']
|
||||
logger.info(f"{operation}操作影响了{affected_rows}行数据")
|
||||
|
||||
return json.dumps(results, default=str)
|
||||
# 添加元数据信息
|
||||
metadata_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": operation,
|
||||
"result_count": len(results)
|
||||
},
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行查询时发生异常: {str(e)}")
|
||||
return json.dumps({"error": str(e)})
|
||||
return json.dumps(metadata_info, default=str)
|
||||
127
src/validators.py
Normal file
127
src/validators.py
Normal file
@ -0,0 +1,127 @@
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""验证错误异常"""
|
||||
pass
|
||||
|
||||
class SQLValidators:
|
||||
"""SQL相关验证器集合"""
|
||||
|
||||
# 正则表达式常量
|
||||
IDENTIFIER_PATTERN = r'^[a-zA-Z0-9_]+$'
|
||||
PATTERN_PATTERN = r'^[a-zA-Z0-9_%]+$'
|
||||
|
||||
@staticmethod
|
||||
def validate_identifier(name: str, entity_type: str = "标识符") -> bool:
|
||||
"""
|
||||
验证SQL标识符是否合法安全(表名、数据库名、列名等)
|
||||
|
||||
Args:
|
||||
name: 要验证的标识符
|
||||
entity_type: 实体类型名称,用于错误信息
|
||||
|
||||
Returns:
|
||||
如果标识符安全返回True
|
||||
|
||||
Raises:
|
||||
ValidationError: 当标识符包含不安全字符时
|
||||
"""
|
||||
if not name:
|
||||
raise ValidationError(f"{entity_type}不能为空")
|
||||
|
||||
if not re.match(SQLValidators.IDENTIFIER_PATTERN, name):
|
||||
raise ValidationError(f"无效的{entity_type}: {name}, {entity_type}只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_table_name(name: str) -> bool:
|
||||
"""验证表名是否合法安全"""
|
||||
return SQLValidators.validate_identifier(name, "表名")
|
||||
|
||||
@staticmethod
|
||||
def validate_database_name(name: str) -> bool:
|
||||
"""验证数据库名是否合法安全"""
|
||||
return SQLValidators.validate_identifier(name, "数据库名")
|
||||
|
||||
@staticmethod
|
||||
def validate_column_name(name: str) -> bool:
|
||||
"""验证列名是否合法安全"""
|
||||
return SQLValidators.validate_identifier(name, "列名")
|
||||
|
||||
@staticmethod
|
||||
def validate_like_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
验证LIKE查询模式是否安全
|
||||
|
||||
Args:
|
||||
pattern: 要验证的模式字符串
|
||||
|
||||
Returns:
|
||||
如果模式安全返回True
|
||||
|
||||
Raises:
|
||||
ValidationError: 当模式包含不安全字符时
|
||||
"""
|
||||
if not pattern:
|
||||
raise ValidationError("模式不能为空")
|
||||
|
||||
if not re.match(SQLValidators.PATTERN_PATTERN, pattern):
|
||||
raise ValidationError(f"无效的模式: {pattern}, 模式只能包含字母、数字、下划线和通配符(%_)")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_integer(value: int, min_value: Optional[int] = None, max_value: Optional[int] = None) -> bool:
|
||||
"""
|
||||
验证整数值是否在允许范围内
|
||||
|
||||
Args:
|
||||
value: 要验证的整数值
|
||||
min_value: 最小允许值(可选)
|
||||
max_value: 最大允许值(可选)
|
||||
|
||||
Returns:
|
||||
如果值合法返回True
|
||||
|
||||
Raises:
|
||||
ValidationError: 当值不合法时
|
||||
"""
|
||||
if not isinstance(value, int):
|
||||
raise ValidationError(f"值必须是整数,当前类型: {type(value).__name__}")
|
||||
|
||||
if min_value is not None and value < min_value:
|
||||
raise ValidationError(f"值必须大于或等于 {min_value}")
|
||||
|
||||
if max_value is not None and value > max_value:
|
||||
raise ValidationError(f"值必须小于或等于 {max_value}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_parameter(param_name: str, param_value: Any, validator: Callable, error_prefix: str = "") -> bool:
|
||||
"""
|
||||
通用参数验证函数
|
||||
|
||||
Args:
|
||||
param_name: 参数名称
|
||||
param_value: 参数值
|
||||
validator: 验证函数
|
||||
error_prefix: 错误信息前缀
|
||||
|
||||
Returns:
|
||||
如果验证通过返回True
|
||||
|
||||
Raises:
|
||||
ValidationError: 当验证失败时
|
||||
"""
|
||||
if param_value is None:
|
||||
return True # 允许None值
|
||||
|
||||
try:
|
||||
return validator(param_value)
|
||||
except ValidationError as e:
|
||||
prefix = f"{error_prefix}: " if error_prefix else ""
|
||||
raise ValidationError(f"{prefix}{param_name} - {str(e)}")
|
||||
except Exception as e:
|
||||
prefix = f"{error_prefix}: " if error_prefix else ""
|
||||
raise ValidationError(f"{prefix}{param_name} - 验证失败: {str(e)}")
|
||||
Reference in New Issue
Block a user