mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 17:52:28 +08:00
<feat> 增加元数据操作功能
This commit is contained in:
83
README.md
83
README.md
@ -16,11 +16,45 @@
|
|||||||
- 危险操作拦截
|
- 危险操作拦截
|
||||||
- WHERE子句强制检查
|
- WHERE子句强制检查
|
||||||
- 自动返回修改操作影响的行数
|
- 自动返回修改操作影响的行数
|
||||||
|
- 敏感信息保护机制
|
||||||
|
- 自动对元数据查询结果进行格式化和增强
|
||||||
|
|
||||||
|
## API接口功能
|
||||||
|
|
||||||
|
系统提供以下四大类工具:
|
||||||
|
|
||||||
|
### 基础查询工具
|
||||||
|
|
||||||
|
- `mysql_query`: 执行任意SQL查询,支持参数化查询
|
||||||
|
|
||||||
|
### 元数据查询工具
|
||||||
|
|
||||||
|
- `mysql_show_tables`: 获取数据库中的表列表,支持模式匹配和限制结果数量
|
||||||
|
- `mysql_show_columns`: 获取表的列信息
|
||||||
|
- `mysql_describe_table`: 描述表结构
|
||||||
|
- `mysql_show_create_table`: 获取表的创建语句
|
||||||
|
|
||||||
|
### 数据库信息查询工具
|
||||||
|
|
||||||
|
- `mysql_show_databases`: 获取所有数据库列表,支持过滤系统数据库
|
||||||
|
- `mysql_show_variables`: 获取MySQL服务器变量
|
||||||
|
- `mysql_show_status`: 获取MySQL服务器状态信息
|
||||||
|
|
||||||
|
### 表结构高级查询工具
|
||||||
|
|
||||||
|
- `mysql_show_indexes`: 获取表的索引信息
|
||||||
|
- `mysql_show_table_status`: 获取表状态信息
|
||||||
|
- `mysql_show_foreign_keys`: 获取表的外键约束信息
|
||||||
|
- `mysql_paginate_results`: 提供结果分页功能
|
||||||
|
|
||||||
## 系统要求
|
## 系统要求
|
||||||
|
|
||||||
- Python 3.6+
|
- Python 3.6+
|
||||||
- MySQL服务器
|
- MySQL服务器
|
||||||
|
- 依赖包:
|
||||||
|
- mysql-connector-python
|
||||||
|
- python-dotenv
|
||||||
|
- mcp (FastMCP框架)
|
||||||
|
|
||||||
## 安装步骤
|
## 安装步骤
|
||||||
|
|
||||||
@ -57,34 +91,33 @@ pip install -r requirements.txt
|
|||||||
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||||
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
||||||
|
- `ALLOW_SENSITIVE_INFO`: 是否允许查询敏感信息(true/false)
|
||||||
|
- `SENSITIVE_INFO_FIELDS`: 自定义敏感字段模式列表(逗号分隔)
|
||||||
|
|
||||||
## SQL安全机制
|
## 安全机制详解
|
||||||
|
|
||||||
### 风险等级控制
|
### 风险等级控制
|
||||||
- LOW: 查询操作(SELECT)
|
- LOW: 查询操作(SELECT)和元数据操作(SHOW, DESCRIBE等)
|
||||||
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||||
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
||||||
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
||||||
|
|
||||||
### 环境变量加载顺序
|
### 环境特性差异
|
||||||
项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。
|
- 开发环境:
|
||||||
|
- 允许较高风险的操作
|
||||||
|
- 不隐藏敏感信息
|
||||||
|
- 提供详细的错误信息
|
||||||
|
- 生产环境:
|
||||||
|
- 默认只允许LOW风险操作
|
||||||
|
- 严格限制数据修改
|
||||||
|
- 自动隐藏敏感信息
|
||||||
|
- 错误信息不暴露实现细节
|
||||||
|
|
||||||
### 安全检查机制
|
### 敏感信息保护
|
||||||
- 强制要求UPDATE/DELETE操作包含WHERE子句
|
系统会自动检测并隐藏包含以下关键词的变量/状态值:
|
||||||
- SQL语句语法检查
|
- password、auth、credential、key、secret、private
|
||||||
- 危险操作模式检测
|
- ssl、tls、cipher、certificate
|
||||||
- 自动识别受影响的表
|
- host、path、directory等系统路径信息
|
||||||
- 风险等级评估
|
|
||||||
|
|
||||||
### 环境特性
|
|
||||||
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
|
|
||||||
- 生产环境:默认只允许LOW风险操作,严格限制数据修改
|
|
||||||
|
|
||||||
### 安全拦截
|
|
||||||
- SQL语句长度限制
|
|
||||||
- 危险操作模式检测
|
|
||||||
- 自动识别受影响的表
|
|
||||||
- SQL注入防护
|
|
||||||
|
|
||||||
### 事务管理
|
### 事务管理
|
||||||
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
||||||
@ -108,11 +141,17 @@ python src/server.py
|
|||||||
├── src/ # 源代码目录
|
├── src/ # 源代码目录
|
||||||
│ ├── server.py # 主服务器文件
|
│ ├── server.py # 主服务器文件
|
||||||
│ ├── db/ # 数据库相关代码
|
│ ├── db/ # 数据库相关代码
|
||||||
|
│ │ └── mysql_operations.py # MySQL操作实现
|
||||||
│ ├── security/ # SQL安全相关代码
|
│ ├── security/ # SQL安全相关代码
|
||||||
│ │ ├── interceptor.py # SQL拦截器
|
│ │ ├── interceptor.py # SQL拦截器
|
||||||
│ │ ├── query_limiter.py # SQL安全检查器
|
│ │ ├── query_limiter.py # SQL安全检查器
|
||||||
│ │ └── sql_analyzer.py # SQL分析器
|
│ │ └── sql_analyzer.py # SQL分析器
|
||||||
│ └── tools/ # 工具类代码
|
│ └── tools/ # 工具类代码
|
||||||
|
│ ├── mysql_tool.py # 基础查询工具
|
||||||
|
│ ├── mysql_metadata_tool.py # 元数据查询工具
|
||||||
|
│ ├── mysql_info_tool.py # 数据库信息查询工具
|
||||||
|
│ ├── mysql_schema_tool.py # 表结构高级查询工具
|
||||||
|
│ └── metadata_base_tool.py # 元数据工具基类
|
||||||
├── tests/ # 测试代码目录
|
├── tests/ # 测试代码目录
|
||||||
├── .env.example # 环境变量示例文件
|
├── .env.example # 环境变量示例文件
|
||||||
└── requirements.txt # 项目依赖文件
|
└── requirements.txt # 项目依赖文件
|
||||||
@ -136,6 +175,10 @@ python src/server.py
|
|||||||
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
||||||
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
||||||
|
|
||||||
|
### 无法查看敏感信息
|
||||||
|
- 在开发环境中,设置ALLOW_SENSITIVE_INFO=true
|
||||||
|
- 在生产环境中,敏感信息默认会被隐藏,这是安全特性
|
||||||
|
|
||||||
## 日志系统
|
## 日志系统
|
||||||
|
|
||||||
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
||||||
|
|||||||
@ -81,6 +81,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
|||||||
ValueError: 当查询执行失败时
|
ValueError: 当查询执行失败时
|
||||||
"""
|
"""
|
||||||
cursor = None
|
cursor = None
|
||||||
|
operation = None # 初始化操作类型变量
|
||||||
try:
|
try:
|
||||||
# 安全检查
|
# 安全检查
|
||||||
if not await sql_interceptor.check_operation(query):
|
if not await sql_interceptor.check_operation(query):
|
||||||
@ -105,6 +106,33 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
|||||||
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||||
return [{'affected_rows': affected_rows}]
|
return [{'affected_rows': affected_rows}]
|
||||||
|
|
||||||
|
# 处理元数据查询操作
|
||||||
|
if operation in sql_analyzer.metadata_operations:
|
||||||
|
# 获取结果集
|
||||||
|
results = cursor.fetchall()
|
||||||
|
|
||||||
|
# 没有结果时返回空列表但添加元信息
|
||||||
|
if not results:
|
||||||
|
logger.debug(f"元数据查询 {operation} 没有返回结果")
|
||||||
|
return [{'metadata_operation': operation, 'result_count': 0}]
|
||||||
|
|
||||||
|
# 优化结果格式 - 为元数据结果添加额外信息
|
||||||
|
metadata_results = []
|
||||||
|
for row in results:
|
||||||
|
# 对某些特定元数据查询进行特殊处理
|
||||||
|
if operation == 'SHOW' and 'Table' in row:
|
||||||
|
# SHOW TABLES 结果增强
|
||||||
|
row['table_name'] = row['Table']
|
||||||
|
elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row:
|
||||||
|
# DESC/DESCRIBE 表结构结果增强
|
||||||
|
row['column_name'] = row['Field']
|
||||||
|
row['data_type'] = row['Type']
|
||||||
|
|
||||||
|
metadata_results.append(row)
|
||||||
|
|
||||||
|
logger.debug(f"元数据查询 {operation} 返回 {len(metadata_results)} 条结果")
|
||||||
|
return metadata_results
|
||||||
|
|
||||||
# 对于查询操作,返回结果集
|
# 对于查询操作,返回结果集
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||||
@ -115,7 +143,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
|||||||
raise
|
raise
|
||||||
except mysql.connector.Error as query_err:
|
except mysql.connector.Error as query_err:
|
||||||
# 如果发生错误,进行回滚
|
# 如果发生错误,进行回滚
|
||||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}: # 确保operation已定义
|
||||||
try:
|
try:
|
||||||
connection.rollback()
|
connection.rollback()
|
||||||
logger.debug("事务已回滚")
|
logger.debug("事务已回滚")
|
||||||
|
|||||||
@ -46,7 +46,15 @@ class SQLInterceptor:
|
|||||||
raise SecurityException("SQL语句格式无效")
|
raise SecurityException("SQL语句格式无效")
|
||||||
|
|
||||||
operation = sql_parts[0].upper()
|
operation = sql_parts[0].upper()
|
||||||
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
|
# 更新支持的操作类型列表,包括元数据操作
|
||||||
|
supported_operations = {
|
||||||
|
'SELECT', 'INSERT', 'UPDATE', 'DELETE',
|
||||||
|
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE',
|
||||||
|
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||||
|
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||||
|
}
|
||||||
|
|
||||||
|
if operation not in supported_operations:
|
||||||
raise SecurityException(f"不支持的SQL操作: {operation}")
|
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||||
|
|
||||||
# 分析SQL风险
|
# 分析SQL风险
|
||||||
@ -65,9 +73,14 @@ class SQLInterceptor:
|
|||||||
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
|
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 确定操作类型(DDL, DML 或 元数据)
|
||||||
|
operation_category = "元数据操作" if operation in self.analyzer.metadata_operations else (
|
||||||
|
"DDL操作" if operation in self.analyzer.ddl_operations else "DML操作"
|
||||||
|
)
|
||||||
|
|
||||||
# 记录详细日志
|
# 记录详细日志
|
||||||
logger.info(
|
logger.info(
|
||||||
f"SQL操作检查通过 - "
|
f"SQL{operation_category}检查通过 - "
|
||||||
f"操作: {risk_analysis['operation']}, "
|
f"操作: {risk_analysis['operation']}, "
|
||||||
f"风险等级: {risk_analysis['risk_level'].name}, "
|
f"风险等级: {risk_analysis['risk_level'].name}, "
|
||||||
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
||||||
|
|||||||
@ -38,6 +38,12 @@ class SQLOperationType:
|
|||||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 添加元数据操作集合
|
||||||
|
self.metadata_operations = {
|
||||||
|
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||||
|
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||||
|
}
|
||||||
|
|
||||||
# 风险等级配置
|
# 风险等级配置
|
||||||
self.allowed_risk_levels = self._parse_risk_levels()
|
self.allowed_risk_levels = self._parse_risk_levels()
|
||||||
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||||
@ -132,11 +138,17 @@ class SQLOperationType:
|
|||||||
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||||
- UPDATE(无WHERE)=> HIGH
|
- UPDATE(无WHERE)=> HIGH
|
||||||
- DELETE(无WHERE)=> CRITICAL
|
- DELETE(无WHERE)=> CRITICAL
|
||||||
|
4. 元数据操作:
|
||||||
|
- SHOW/DESC/DESCRIBE等 => LOW
|
||||||
"""
|
"""
|
||||||
# 危险操作
|
# 危险操作
|
||||||
if is_dangerous:
|
if is_dangerous:
|
||||||
return SQLRiskLevel.CRITICAL
|
return SQLRiskLevel.CRITICAL
|
||||||
|
|
||||||
|
# 元数据操作
|
||||||
|
if operation in self.metadata_operations:
|
||||||
|
return SQLRiskLevel.LOW # 元数据查询视为低风险操作
|
||||||
|
|
||||||
# 生产环境中非SELECT操作
|
# 生产环境中非SELECT操作
|
||||||
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||||
return SQLRiskLevel.CRITICAL
|
return SQLRiskLevel.CRITICAL
|
||||||
|
|||||||
@ -8,6 +8,9 @@ load_dotenv()
|
|||||||
|
|
||||||
# 导入自定义模块 - 确保在load_dotenv之后导入
|
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||||
from src.tools.mysql_tool import register_mysql_tool
|
from src.tools.mysql_tool import register_mysql_tool
|
||||||
|
from src.tools.mysql_metadata_tool import register_metadata_tools
|
||||||
|
from src.tools.mysql_info_tool import register_info_tools
|
||||||
|
from src.tools.mysql_schema_tool import register_schema_tools
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -19,6 +22,8 @@ logger = logging.getLogger("mysql_server")
|
|||||||
# 记录环境变量加载情况
|
# 记录环境变量加载情况
|
||||||
logger.debug("已加载环境变量")
|
logger.debug("已加载环境变量")
|
||||||
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||||
|
logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}")
|
||||||
|
logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}")
|
||||||
|
|
||||||
# 尝试导入MySQL连接器
|
# 尝试导入MySQL连接器
|
||||||
try:
|
try:
|
||||||
@ -40,9 +45,21 @@ logger.debug("正在创建MCP服务器实例...")
|
|||||||
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
|
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
|
||||||
logger.debug("MCP服务器实例创建完成")
|
logger.debug("MCP服务器实例创建完成")
|
||||||
|
|
||||||
# 注册MySQL工具
|
# 注册MySQL基础查询工具
|
||||||
register_mysql_tool(mcp)
|
register_mysql_tool(mcp)
|
||||||
|
|
||||||
|
# 注册MySQL元数据查询工具
|
||||||
|
register_metadata_tools(mcp)
|
||||||
|
logger.debug("已注册元数据查询工具")
|
||||||
|
|
||||||
|
# 注册MySQL数据库信息查询工具
|
||||||
|
register_info_tools(mcp)
|
||||||
|
logger.debug("已注册数据库信息查询工具")
|
||||||
|
|
||||||
|
# 注册MySQL表结构高级查询工具
|
||||||
|
register_schema_tools(mcp)
|
||||||
|
logger.debug("已注册表结构高级查询工具")
|
||||||
|
|
||||||
def start_server():
|
def start_server():
|
||||||
"""启动SSE服务器的同步包装器"""
|
"""启动SSE服务器的同步包装器"""
|
||||||
logger.debug("开始启动MySQL查询服务器...")
|
logger.debug("开始启动MySQL查询服务器...")
|
||||||
|
|||||||
123
src/tools/metadata_base_tool.py
Normal file
123
src/tools/metadata_base_tool.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
MySQL元数据工具基类
|
||||||
|
提供元数据查询工具的共享功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
class MySQLToolError(Exception):
|
||||||
|
"""MySQL工具异常基类"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ParameterValidationError(MySQLToolError):
|
||||||
|
"""参数验证错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class QueryExecutionError(MySQLToolError):
|
||||||
|
"""查询执行错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class MetadataToolBase:
|
||||||
|
"""
|
||||||
|
元数据查询工具基类
|
||||||
|
提供共享功能:
|
||||||
|
- 错误处理
|
||||||
|
- 结果格式化
|
||||||
|
- 参数验证
|
||||||
|
- 异常处理
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_parameter(param_name: str, param_value: Any, validator: Callable[[Any], bool],
|
||||||
|
error_message: str) -> None:
|
||||||
|
"""
|
||||||
|
验证参数是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_name: 参数名称
|
||||||
|
param_value: 参数值
|
||||||
|
validator: 验证函数
|
||||||
|
error_message: 错误消息
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ParameterValidationError: 当参数验证失败时
|
||||||
|
"""
|
||||||
|
if param_value is not None and not validator(param_value):
|
||||||
|
raise ParameterValidationError(f"{param_name} - {error_message}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str:
|
||||||
|
"""
|
||||||
|
格式化查询结果为JSON字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: 查询结果列表
|
||||||
|
operation_type: 操作类型描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的JSON字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 添加元数据头部信息
|
||||||
|
metadata_info = {
|
||||||
|
"metadata_info": {
|
||||||
|
"operation_type": operation_type,
|
||||||
|
"result_count": len(results)
|
||||||
|
},
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
return json.dumps(metadata_info, default=str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"结果格式化失败: {str(e)}")
|
||||||
|
# 如果格式化失败,尝试直接序列化结果
|
||||||
|
return json.dumps({"error": f"结果格式化失败: {str(e)}"})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handle_query_error(func):
|
||||||
|
"""
|
||||||
|
装饰器: 统一处理查询执行过程中的错误
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except ParameterValidationError as e:
|
||||||
|
logger.error(f"参数验证错误: {str(e)}")
|
||||||
|
return json.dumps({"error": f"参数错误: {str(e)}"})
|
||||||
|
except QueryExecutionError as e:
|
||||||
|
logger.error(f"查询执行错误: {str(e)}")
|
||||||
|
return json.dumps({"error": f"查询执行失败: {str(e)}"})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"未预期的错误: {str(e)}")
|
||||||
|
return json.dumps({"error": f"操作失败: {str(e)}"})
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def execute_metadata_query(query: str, params: Optional[Dict[str, Any]] = None,
|
||||||
|
operation_type: str = "元数据查询") -> str:
|
||||||
|
"""
|
||||||
|
执行元数据查询并返回格式化结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQL查询语句
|
||||||
|
params: 查询参数 (可选)
|
||||||
|
operation_type: 操作类型描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
查询结果的JSON字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
results = await execute_query(connection, query, params)
|
||||||
|
return MetadataToolBase.format_results(results, operation_type)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"元数据查询执行失败: {str(e)}")
|
||||||
|
raise QueryExecutionError(str(e))
|
||||||
332
src/tools/mysql_info_tool.py
Normal file
332
src/tools/mysql_info_tool.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
"""
|
||||||
|
MySQL数据库信息查询工具
|
||||||
|
提供数据库、变量和状态等系统信息查询功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||||
|
from src.security.sql_analyzer import EnvironmentType
|
||||||
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
|
|
||||||
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
# 自定义异常类
|
||||||
|
class SecurityError(QueryExecutionError):
|
||||||
|
"""安全限制错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 从环境变量读取敏感字段列表
|
||||||
|
def get_sensitive_patterns():
|
||||||
|
"""从环境变量获取敏感字段模式列表"""
|
||||||
|
default_patterns = [
|
||||||
|
r'password', r'auth', r'credential', r'key', r'secret', r'private',
|
||||||
|
r'host', r'path', r'directory', r'ssl', r'iptables', r'filter'
|
||||||
|
]
|
||||||
|
|
||||||
|
env_patterns = os.getenv('SENSITIVE_INFO_FIELDS', '')
|
||||||
|
if env_patterns:
|
||||||
|
# 合并自定义模式和默认模式
|
||||||
|
patterns = [pattern.strip() for pattern in env_patterns.split(',') if pattern.strip()]
|
||||||
|
return list(set(patterns + default_patterns))
|
||||||
|
|
||||||
|
return default_patterns
|
||||||
|
|
||||||
|
# 敏感变量和状态关键字列表
|
||||||
|
SENSITIVE_VARIABLE_PATTERNS = get_sensitive_patterns()
|
||||||
|
|
||||||
|
# 敏感变量名前缀,生产环境中这些变量的值会被隐藏
|
||||||
|
SENSITIVE_VARIABLE_PREFIXES = [
|
||||||
|
"password", "auth", "secret", "key", "certificate", "ssl", "tls", "cipher",
|
||||||
|
"authentication", "secure", "credential", "token"
|
||||||
|
]
|
||||||
|
|
||||||
|
def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查当前环境是否允许执行特定类型的查询
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_type: 环境类型(开发/生产)
|
||||||
|
query_type: 查询类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否允许执行
|
||||||
|
"""
|
||||||
|
# 开发环境不限制查询
|
||||||
|
if env_type == EnvironmentType.DEVELOPMENT:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 生产环境限制敏感信息查询
|
||||||
|
sensitive_queries = ['variables', 'status', 'processlist']
|
||||||
|
|
||||||
|
if query_type in sensitive_queries:
|
||||||
|
# 检查是否在环境变量中明确允许
|
||||||
|
allow_sensitive = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() == 'true'
|
||||||
|
if not allow_sensitive:
|
||||||
|
logger.warning(f"生产环境中禁止执行敏感查询: {query_type}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
过滤结果中的敏感信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: 查询结果
|
||||||
|
filter_patterns: 敏感信息的正则表达式模式列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过滤后的结果列表
|
||||||
|
"""
|
||||||
|
if not filter_patterns:
|
||||||
|
filter_patterns = SENSITIVE_VARIABLE_PATTERNS
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for item in results:
|
||||||
|
# 复制一份,避免修改原始数据
|
||||||
|
filtered_item = item.copy()
|
||||||
|
|
||||||
|
# 检查常见的变量名字段
|
||||||
|
for field in ['Variable_name', 'variable_name', 'name']:
|
||||||
|
if field in filtered_item:
|
||||||
|
var_name = filtered_item[field].lower()
|
||||||
|
# 检查是否匹配敏感模式
|
||||||
|
is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
|
||||||
|
|
||||||
|
if is_sensitive:
|
||||||
|
# 敏感信息,隐藏具体的值
|
||||||
|
for value_field in ['Value', 'value', 'variable_value']:
|
||||||
|
if value_field in filtered_item:
|
||||||
|
filtered_item[value_field] = '*** HIDDEN ***'
|
||||||
|
|
||||||
|
filtered_results.append(filtered_item)
|
||||||
|
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def register_info_tools(mcp: FastMCP):
|
||||||
|
"""
|
||||||
|
注册MySQL数据库信息查询工具到MCP服务器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp: FastMCP服务器实例
|
||||||
|
"""
|
||||||
|
logger.debug("注册MySQL数据库信息查询工具...")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
|
||||||
|
"""
|
||||||
|
获取所有数据库列表,支持筛选和限制结果数量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 数据库名称匹配模式 (可选, 例如 '%test%')
|
||||||
|
limit: 返回结果的最大数量 (默认100,设为0表示无限制)
|
||||||
|
exclude_system: 是否排除系统数据库 (默认为True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
数据库列表的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
if pattern:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"pattern", pattern,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||||
|
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||||
|
)
|
||||||
|
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"limit", limit,
|
||||||
|
lambda x: isinstance(x, int) and x >= 0,
|
||||||
|
"返回结果的最大数量必须是非负整数"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建基础查询
|
||||||
|
query = "SHOW DATABASES"
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
# 先获取所有数据库
|
||||||
|
results = await execute_query(connection, query)
|
||||||
|
|
||||||
|
# 通常结果中每个数据库名会在"Database"字段
|
||||||
|
db_field = next((k for k in results[0].keys() if k.lower() == 'database'), None) if results else None
|
||||||
|
|
||||||
|
if not db_field:
|
||||||
|
logger.warning("查询结果未找到数据库名称字段")
|
||||||
|
return MetadataToolBase.format_results(results, operation_type="数据库列表查询")
|
||||||
|
|
||||||
|
# 对结果进行过滤
|
||||||
|
filtered_results = []
|
||||||
|
system_dbs = ['information_schema', 'mysql', 'performance_schema', 'sys']
|
||||||
|
|
||||||
|
for item in results:
|
||||||
|
db_name = item[db_field]
|
||||||
|
|
||||||
|
# 排除系统数据库
|
||||||
|
if exclude_system and db_name.lower() in system_dbs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 根据模式过滤
|
||||||
|
if pattern:
|
||||||
|
pattern_regex = pattern.replace('%', '.*').replace('_', '.')
|
||||||
|
if not re.search(pattern_regex, db_name, re.IGNORECASE):
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered_results.append(item)
|
||||||
|
|
||||||
|
# 限制返回数量
|
||||||
|
if limit > 0 and len(filtered_results) > limit:
|
||||||
|
filtered_results = filtered_results[:limit]
|
||||||
|
logger.debug(f"结果数量已限制为前{limit}个")
|
||||||
|
|
||||||
|
# 返回结果
|
||||||
|
metadata_info = {
|
||||||
|
"metadata_info": {
|
||||||
|
"operation_type": "数据库列表查询",
|
||||||
|
"result_count": len(filtered_results),
|
||||||
|
"total_count": len(results),
|
||||||
|
"filtered": {
|
||||||
|
"pattern": pattern,
|
||||||
|
"exclude_system": exclude_system,
|
||||||
|
"limited": len(filtered_results) < len(results)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"results": filtered_results
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(metadata_info, default=str)
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_variables(pattern: Optional[str] = None, global_scope: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
获取MySQL系统变量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 变量名称匹配模式 (可选, 例如 '%buffer%')
|
||||||
|
global_scope: 是否查询全局变量 (默认为会话变量)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
系统变量的JSON字符串
|
||||||
|
"""
|
||||||
|
# 获取当前环境类型
|
||||||
|
from src.security.sql_analyzer import sql_analyzer
|
||||||
|
env_type = sql_analyzer.env_type
|
||||||
|
|
||||||
|
# 检查环境权限
|
||||||
|
if not check_environment_permission(env_type, 'variables'):
|
||||||
|
raise SecurityError("当前环境不允许查询系统变量")
|
||||||
|
|
||||||
|
# 参数验证
|
||||||
|
if pattern:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"pattern", pattern,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||||
|
"变量模式只能包含字母、数字、下划线和通配符(%_)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
scope = "GLOBAL" if global_scope else "SESSION"
|
||||||
|
query = f"SHOW {scope} VARIABLES"
|
||||||
|
if pattern:
|
||||||
|
query += f" LIKE '{pattern}'"
|
||||||
|
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
results = await execute_query(connection, query)
|
||||||
|
|
||||||
|
# 生产环境中过滤敏感信息
|
||||||
|
if env_type == EnvironmentType.PRODUCTION:
|
||||||
|
results = filter_sensitive_info(results)
|
||||||
|
|
||||||
|
return MetadataToolBase.format_results(results, operation_type="系统变量查询")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_status(pattern: Optional[str] = None, global_scope: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
获取MySQL服务器状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 状态名称匹配模式 (可选, 例如 '%conn%')
|
||||||
|
global_scope: 是否查询全局状态 (默认为会话状态)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
服务器状态的JSON字符串
|
||||||
|
"""
|
||||||
|
# 获取当前环境类型
|
||||||
|
from src.security.sql_analyzer import sql_analyzer
|
||||||
|
env_type = sql_analyzer.env_type
|
||||||
|
|
||||||
|
# 检查环境权限
|
||||||
|
if not check_environment_permission(env_type, 'status'):
|
||||||
|
raise SecurityError("当前环境不允许查询系统状态")
|
||||||
|
|
||||||
|
# 参数验证
|
||||||
|
if pattern:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"pattern", pattern,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||||
|
"状态模式只能包含字母、数字、下划线和通配符(%_)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
scope = "GLOBAL" if global_scope else "SESSION"
|
||||||
|
query = f"SHOW {scope} STATUS"
|
||||||
|
if pattern:
|
||||||
|
query += f" LIKE '{pattern}'"
|
||||||
|
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
results = await execute_query(connection, query)
|
||||||
|
|
||||||
|
# 生产环境中过滤敏感信息
|
||||||
|
if env_type == EnvironmentType.PRODUCTION:
|
||||||
|
results = filter_sensitive_info(results)
|
||||||
|
|
||||||
|
return MetadataToolBase.format_results(results, operation_type="服务器状态查询")
|
||||||
|
|
||||||
|
# 工具函数: 用于参数验证
|
||||||
|
def validate_pattern(pattern: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证模式字符串是否安全 (防止SQL注入)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 要验证的模式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果模式安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当模式包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||||
|
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_engine_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证存储引擎名称是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的引擎名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果引擎名称安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当引擎名称包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
272
src/tools/mysql_metadata_tool.py
Normal file
272
src/tools/mysql_metadata_tool.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
"""
|
||||||
|
MySQL元数据查询工具
|
||||||
|
提供表结构等元数据信息查询功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||||
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
|
|
||||||
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
# 工具函数: 用于参数验证
|
||||||
|
def validate_pattern(pattern: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证模式字符串是否安全 (防止SQL注入)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 要验证的模式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果模式安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当模式包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||||
|
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_table_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证表名是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果表名安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当表名包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_database_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证数据库名是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的数据库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果数据库名安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据库名包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def register_metadata_tools(mcp: FastMCP):
|
||||||
|
"""
|
||||||
|
注册MySQL元数据查询工具到MCP服务器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp: FastMCP服务器实例
|
||||||
|
"""
|
||||||
|
logger.debug("注册MySQL元数据查询工具...")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_tables(database: Optional[str] = None, pattern: Optional[str] = None,
|
||||||
|
limit: int = 100, exclude_views: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
获取数据库中的表列表,支持筛选和限制结果数量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
pattern: 表名匹配模式 (可选, 例如 '%user%')
|
||||||
|
limit: 返回结果的最大数量 (默认100,设为0表示无限制)
|
||||||
|
exclude_views: 是否排除视图 (默认为False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表列表的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
if database:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"database", database,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"数据库名称只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pattern:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"pattern", pattern,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||||
|
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||||
|
)
|
||||||
|
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"limit", limit,
|
||||||
|
lambda x: isinstance(x, int) and x >= 0,
|
||||||
|
"返回结果的最大数量必须是非负整数"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 基础查询
|
||||||
|
base_query = "SHOW FULL TABLES" if exclude_views else "SHOW TABLES"
|
||||||
|
if database:
|
||||||
|
base_query += f" FROM `{database}`"
|
||||||
|
if pattern:
|
||||||
|
base_query += f" LIKE '{pattern}'"
|
||||||
|
|
||||||
|
logger.debug(f"执行查询: {base_query}")
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
results = await execute_query(connection, base_query)
|
||||||
|
|
||||||
|
# 如果需要排除视图,且使用的是SHOW FULL TABLES
|
||||||
|
if exclude_views and "FULL" in base_query:
|
||||||
|
filtered_results = []
|
||||||
|
|
||||||
|
# 查找表名和表类型字段
|
||||||
|
fields = list(results[0].keys()) if results else []
|
||||||
|
table_field = fields[0] if fields else None
|
||||||
|
table_type_field = fields[1] if len(fields) > 1 else None
|
||||||
|
|
||||||
|
if table_field and table_type_field:
|
||||||
|
# 基表类型通常是"BASE TABLE"
|
||||||
|
for item in results:
|
||||||
|
if item[table_type_field] == 'BASE TABLE':
|
||||||
|
filtered_results.append(item)
|
||||||
|
else:
|
||||||
|
filtered_results = results
|
||||||
|
else:
|
||||||
|
filtered_results = results
|
||||||
|
|
||||||
|
# 限制返回数量
|
||||||
|
if limit > 0 and len(filtered_results) > limit:
|
||||||
|
limited_results = filtered_results[:limit]
|
||||||
|
is_limited = True
|
||||||
|
else:
|
||||||
|
limited_results = filtered_results
|
||||||
|
is_limited = False
|
||||||
|
|
||||||
|
# 构造元数据
|
||||||
|
metadata_info = {
|
||||||
|
"metadata_info": {
|
||||||
|
"operation_type": "表列表查询",
|
||||||
|
"result_count": len(limited_results),
|
||||||
|
"total_count": len(results),
|
||||||
|
"filtered": {
|
||||||
|
"database": database,
|
||||||
|
"pattern": pattern,
|
||||||
|
"exclude_views": exclude_views and "FULL" in base_query,
|
||||||
|
"limited": is_limited
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"results": limited_results
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(metadata_info, default=str)
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_columns(table: str, database: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
获取表的列信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表列信息的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"table", table,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"表名只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
if database:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"database", database,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"数据库名称只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
query = f"SHOW COLUMNS FROM `{table}`" if not database else f"SHOW COLUMNS FROM `{database}`.`{table}`"
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
return await MetadataToolBase.execute_metadata_query(query, operation_type="表列信息查询")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_describe_table(table: str, database: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
描述表结构
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表结构描述的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"table", table,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"表名只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
if database:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"database", database,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"数据库名称只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
# DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
|
||||||
|
query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_create_table(table: str, database: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
获取表的创建语句
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表创建语句的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"table", table,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"表名只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
if database:
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"database", database,
|
||||||
|
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||||
|
"数据库名称只能包含字母、数字和下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
|
||||||
|
query = f"SHOW CREATE TABLE {table_ref}"
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
return await MetadataToolBase.execute_metadata_query(query, operation_type="表创建语句查询")
|
||||||
294
src/tools/mysql_schema_tool.py
Normal file
294
src/tools/mysql_schema_tool.py
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
"""
|
||||||
|
MySQL表结构高级查询工具
|
||||||
|
提供索引、约束、表状态等高级元数据查询功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||||
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
|
|
||||||
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
# 参数验证函数
|
||||||
|
def validate_table_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证表名是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果表名安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当表名包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_database_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证数据库名是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的数据库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果数据库名安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据库名包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_column_name(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证列名是否合法安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要验证的列名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果列名安全返回True,否则抛出ValueError
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当列名包含不安全字符时
|
||||||
|
"""
|
||||||
|
# 仅允许字母、数字和下划线
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||||
|
raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def execute_schema_query(
|
||||||
|
query: str,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
operation_type: str = "元数据查询"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
执行表结构查询
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQL查询语句
|
||||||
|
params: 查询参数 (可选)
|
||||||
|
operation_type: 操作类型描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
查询结果的JSON字符串
|
||||||
|
"""
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
results = await execute_query(connection, query, params)
|
||||||
|
return MetadataToolBase.format_results(results, operation_type)
|
||||||
|
|
||||||
|
def register_schema_tools(mcp: FastMCP):
|
||||||
|
"""
|
||||||
|
注册MySQL表结构高级查询工具到MCP服务器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp: FastMCP服务器实例
|
||||||
|
"""
|
||||||
|
logger.debug("注册MySQL表结构高级查询工具...")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_indexes(table: str, database: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
获取表的索引信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表索引信息的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
validate_table_name(table)
|
||||||
|
|
||||||
|
if database:
|
||||||
|
validate_database_name(database)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
|
||||||
|
query = f"SHOW INDEX FROM {table_ref}"
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
return await execute_schema_query(query, operation_type="表索引查询")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_table_status(database: Optional[str] = None, like_pattern: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
获取表状态信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
like_pattern: 表名匹配模式 (可选,例如 '%user%')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表状态信息的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
if database:
|
||||||
|
validate_database_name(database)
|
||||||
|
|
||||||
|
if like_pattern:
|
||||||
|
validate_column_name(like_pattern)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
if database:
|
||||||
|
query = f"SHOW TABLE STATUS FROM `{database}`"
|
||||||
|
else:
|
||||||
|
query = "SHOW TABLE STATUS"
|
||||||
|
|
||||||
|
if like_pattern:
|
||||||
|
query += f" LIKE '{like_pattern}'"
|
||||||
|
|
||||||
|
logger.debug(f"执行查询: {query}")
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
return await execute_schema_query(query, operation_type="表状态查询")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_show_foreign_keys(table: str, database: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
获取表的外键约束信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表外键约束信息的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
validate_table_name(table)
|
||||||
|
|
||||||
|
if database:
|
||||||
|
validate_database_name(database)
|
||||||
|
|
||||||
|
# 确定数据库名
|
||||||
|
db_name = database
|
||||||
|
if not db_name:
|
||||||
|
# 获取当前数据库
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
current_db_results = await execute_query(connection, "SELECT DATABASE() as db")
|
||||||
|
if current_db_results and 'db' in current_db_results[0]:
|
||||||
|
db_name = current_db_results[0]['db']
|
||||||
|
|
||||||
|
if not db_name:
|
||||||
|
raise ValueError("无法确定数据库名称,请明确指定database参数")
|
||||||
|
|
||||||
|
# 使用INFORMATION_SCHEMA查询外键
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
CONSTRAINT_NAME,
|
||||||
|
TABLE_NAME,
|
||||||
|
COLUMN_NAME,
|
||||||
|
REFERENCED_TABLE_NAME,
|
||||||
|
REFERENCED_COLUMN_NAME,
|
||||||
|
UPDATE_RULE,
|
||||||
|
DELETE_RULE
|
||||||
|
FROM
|
||||||
|
INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
|
||||||
|
JOIN
|
||||||
|
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
|
||||||
|
ON
|
||||||
|
kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
|
||||||
|
WHERE
|
||||||
|
kcu.TABLE_SCHEMA = %s
|
||||||
|
AND kcu.TABLE_NAME = %s
|
||||||
|
AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
|
||||||
|
"""
|
||||||
|
params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table}
|
||||||
|
|
||||||
|
logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束")
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
return await execute_schema_query(query, params, operation_type="外键约束查询")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
@MetadataToolBase.handle_query_error
|
||||||
|
async def mysql_paginate_results(query: str, page: int = 1, page_size: int = 50) -> str:
|
||||||
|
"""
|
||||||
|
分页执行查询以处理大型结果集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQL查询语句
|
||||||
|
page: 页码 (从1开始)
|
||||||
|
page_size: 每页记录数 (默认50)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分页结果的JSON字符串
|
||||||
|
"""
|
||||||
|
# 参数验证
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"page", page,
|
||||||
|
lambda x: isinstance(x, int) and x > 0,
|
||||||
|
"页码必须是正整数"
|
||||||
|
)
|
||||||
|
|
||||||
|
MetadataToolBase.validate_parameter(
|
||||||
|
"page_size", page_size,
|
||||||
|
lambda x: isinstance(x, int) and 1 <= x <= 1000,
|
||||||
|
"每页记录数必须在1-1000之间"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查查询语法
|
||||||
|
if not query.strip().upper().startswith('SELECT'):
|
||||||
|
raise ValueError("只支持SELECT查询的分页")
|
||||||
|
|
||||||
|
# 计算LIMIT和OFFSET
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 在查询末尾添加LIMIT子句
|
||||||
|
paginated_query = query.strip()
|
||||||
|
if 'LIMIT' in paginated_query.upper():
|
||||||
|
raise ValueError("查询已包含LIMIT子句,请移除后重试")
|
||||||
|
|
||||||
|
paginated_query += f" LIMIT {page_size} OFFSET {offset}"
|
||||||
|
|
||||||
|
logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}")
|
||||||
|
|
||||||
|
# 获取总记录数(用于计算总页数)
|
||||||
|
count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table"
|
||||||
|
|
||||||
|
with get_db_connection() as connection:
|
||||||
|
# 执行分页查询
|
||||||
|
results = await execute_query(connection, paginated_query)
|
||||||
|
|
||||||
|
# 获取总记录数
|
||||||
|
count_results = await execute_query(connection, count_query)
|
||||||
|
total_records = count_results[0]['total'] if count_results else 0
|
||||||
|
|
||||||
|
# 计算总页数
|
||||||
|
total_pages = (total_records + page_size - 1) // page_size
|
||||||
|
|
||||||
|
# 构建分页元数据
|
||||||
|
pagination_info = {
|
||||||
|
"metadata_info": {
|
||||||
|
"operation_type": "分页查询",
|
||||||
|
"result_count": len(results),
|
||||||
|
"pagination": {
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"total_records": total_records,
|
||||||
|
"total_pages": total_pages
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(pagination_info, default=str)
|
||||||
Reference in New Issue
Block a user