mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 09:42:27 +08:00
<feat> 增加元数据操作功能
This commit is contained in:
101
README.md
101
README.md
@ -16,11 +16,45 @@
|
||||
- 危险操作拦截
|
||||
- WHERE子句强制检查
|
||||
- 自动返回修改操作影响的行数
|
||||
- 敏感信息保护机制
|
||||
- 自动对元数据查询结果进行格式化和增强
|
||||
|
||||
## API接口功能
|
||||
|
||||
系统提供以下四大类工具:
|
||||
|
||||
### 基础查询工具
|
||||
|
||||
- `mysql_query`: 执行任意SQL查询,支持参数化查询
|
||||
|
||||
### 元数据查询工具
|
||||
|
||||
- `mysql_show_tables`: 获取数据库中的表列表,支持模式匹配和限制结果数量
|
||||
- `mysql_show_columns`: 获取表的列信息
|
||||
- `mysql_describe_table`: 描述表结构
|
||||
- `mysql_show_create_table`: 获取表的创建语句
|
||||
|
||||
### 数据库信息查询工具
|
||||
|
||||
- `mysql_show_databases`: 获取所有数据库列表,支持过滤系统数据库
|
||||
- `mysql_show_variables`: 获取MySQL服务器变量
|
||||
- `mysql_show_status`: 获取MySQL服务器状态信息
|
||||
|
||||
### 表结构高级查询工具
|
||||
|
||||
- `mysql_show_indexes`: 获取表的索引信息
|
||||
- `mysql_show_table_status`: 获取表状态信息
|
||||
- `mysql_show_foreign_keys`: 获取表的外键约束信息
|
||||
- `mysql_paginate_results`: 提供结果分页功能
|
||||
|
||||
## 系统要求
|
||||
|
||||
- Python 3.6+
|
||||
- MySQL服务器
|
||||
- 依赖包:
|
||||
- mysql-connector-python
|
||||
- python-dotenv
|
||||
- mcp (FastMCP框架)
|
||||
|
||||
## 安装步骤
|
||||
|
||||
@ -57,34 +91,33 @@ pip install -r requirements.txt
|
||||
- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||
- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔)
|
||||
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false)
|
||||
- `ALLOW_SENSITIVE_INFO`: 是否允许查询敏感信息(true/false)
|
||||
- `SENSITIVE_INFO_FIELDS`: 自定义敏感字段模式列表(逗号分隔)
|
||||
|
||||
## SQL安全机制
|
||||
## 安全机制详解
|
||||
|
||||
### 风险等级控制
|
||||
- LOW: 查询操作(SELECT)
|
||||
- LOW: 查询操作(SELECT)和元数据操作(SHOW, DESCRIBE等)
|
||||
- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE
|
||||
- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作
|
||||
|
||||
### 环境变量加载顺序
|
||||
项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。
|
||||
### 环境特性差异
|
||||
- 开发环境:
|
||||
- 允许较高风险的操作
|
||||
- 不隐藏敏感信息
|
||||
- 提供详细的错误信息
|
||||
- 生产环境:
|
||||
- 默认只允许LOW风险操作
|
||||
- 严格限制数据修改
|
||||
- 自动隐藏敏感信息
|
||||
- 错误信息不暴露实现细节
|
||||
|
||||
### 安全检查机制
|
||||
- 强制要求UPDATE/DELETE操作包含WHERE子句
|
||||
- SQL语句语法检查
|
||||
- 危险操作模式检测
|
||||
- 自动识别受影响的表
|
||||
- 风险等级评估
|
||||
|
||||
### 环境特性
|
||||
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
|
||||
- 生产环境:默认只允许LOW风险操作,严格限制数据修改
|
||||
|
||||
### 安全拦截
|
||||
- SQL语句长度限制
|
||||
- 危险操作模式检测
|
||||
- 自动识别受影响的表
|
||||
- SQL注入防护
|
||||
### 敏感信息保护
|
||||
系统会自动检测并隐藏包含以下关键词的变量/状态值:
|
||||
- password、auth、credential、key、secret、private
|
||||
- ssl、tls、cipher、certificate
|
||||
- host、path、directory等系统路径信息
|
||||
|
||||
### 事务管理
|
||||
- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务
|
||||
@ -105,17 +138,23 @@ python src/server.py
|
||||
|
||||
```
|
||||
.
|
||||
├── src/ # 源代码目录
|
||||
│ ├── server.py # 主服务器文件
|
||||
│ ├── db/ # 数据库相关代码
|
||||
│ ├── security/ # SQL安全相关代码
|
||||
│ │ ├── interceptor.py # SQL拦截器
|
||||
├── src/ # 源代码目录
|
||||
│ ├── server.py # 主服务器文件
|
||||
│ ├── db/ # 数据库相关代码
|
||||
│ │ └── mysql_operations.py # MySQL操作实现
|
||||
│ ├── security/ # SQL安全相关代码
|
||||
│ │ ├── interceptor.py # SQL拦截器
|
||||
│ │ ├── query_limiter.py # SQL安全检查器
|
||||
│ │ └── sql_analyzer.py # SQL分析器
|
||||
│ └── tools/ # 工具类代码
|
||||
├── tests/ # 测试代码目录
|
||||
├── .env.example # 环境变量示例文件
|
||||
└── requirements.txt # 项目依赖文件
|
||||
│ └── tools/ # 工具类代码
|
||||
│ ├── mysql_tool.py # 基础查询工具
|
||||
│ ├── mysql_metadata_tool.py # 元数据查询工具
|
||||
│ ├── mysql_info_tool.py # 数据库信息查询工具
|
||||
│ ├── mysql_schema_tool.py # 表结构高级查询工具
|
||||
│ └── metadata_base_tool.py # 元数据工具基类
|
||||
├── tests/ # 测试代码目录
|
||||
├── .env.example # 环境变量示例文件
|
||||
└── requirements.txt # 项目依赖文件
|
||||
```
|
||||
|
||||
## 常见问题解决
|
||||
@ -136,6 +175,10 @@ python src/server.py
|
||||
- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS
|
||||
- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别
|
||||
|
||||
### 无法查看敏感信息
|
||||
- 在开发环境中,设置ALLOW_SENSITIVE_INFO=true
|
||||
- 在生产环境中,敏感信息默认会被隐藏,这是安全特性
|
||||
|
||||
## 日志系统
|
||||
|
||||
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。
|
||||
|
||||
@ -81,6 +81,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
||||
ValueError: 当查询执行失败时
|
||||
"""
|
||||
cursor = None
|
||||
operation = None # 初始化操作类型变量
|
||||
try:
|
||||
# 安全检查
|
||||
if not await sql_interceptor.check_operation(query):
|
||||
@ -105,6 +106,33 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
||||
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||
return [{'affected_rows': affected_rows}]
|
||||
|
||||
# 处理元数据查询操作
|
||||
if operation in sql_analyzer.metadata_operations:
|
||||
# 获取结果集
|
||||
results = cursor.fetchall()
|
||||
|
||||
# 没有结果时返回空列表但添加元信息
|
||||
if not results:
|
||||
logger.debug(f"元数据查询 {operation} 没有返回结果")
|
||||
return [{'metadata_operation': operation, 'result_count': 0}]
|
||||
|
||||
# 优化结果格式 - 为元数据结果添加额外信息
|
||||
metadata_results = []
|
||||
for row in results:
|
||||
# 对某些特定元数据查询进行特殊处理
|
||||
if operation == 'SHOW' and 'Table' in row:
|
||||
# SHOW TABLES 结果增强
|
||||
row['table_name'] = row['Table']
|
||||
elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row:
|
||||
# DESC/DESCRIBE 表结构结果增强
|
||||
row['column_name'] = row['Field']
|
||||
row['data_type'] = row['Type']
|
||||
|
||||
metadata_results.append(row)
|
||||
|
||||
logger.debug(f"元数据查询 {operation} 返回 {len(metadata_results)} 条结果")
|
||||
return metadata_results
|
||||
|
||||
# 对于查询操作,返回结果集
|
||||
results = cursor.fetchall()
|
||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||
@ -115,7 +143,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
|
||||
raise
|
||||
except mysql.connector.Error as query_err:
|
||||
# 如果发生错误,进行回滚
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}: # 确保operation已定义
|
||||
try:
|
||||
connection.rollback()
|
||||
logger.debug("事务已回滚")
|
||||
|
||||
@ -46,7 +46,15 @@ class SQLInterceptor:
|
||||
raise SecurityException("SQL语句格式无效")
|
||||
|
||||
operation = sql_parts[0].upper()
|
||||
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
|
||||
# 更新支持的操作类型列表,包括元数据操作
|
||||
supported_operations = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE',
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE',
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
|
||||
if operation not in supported_operations:
|
||||
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||
|
||||
# 分析SQL风险
|
||||
@ -65,9 +73,14 @@ class SQLInterceptor:
|
||||
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
|
||||
)
|
||||
|
||||
# 确定操作类型(DDL, DML 或 元数据)
|
||||
operation_category = "元数据操作" if operation in self.analyzer.metadata_operations else (
|
||||
"DDL操作" if operation in self.analyzer.ddl_operations else "DML操作"
|
||||
)
|
||||
|
||||
# 记录详细日志
|
||||
logger.info(
|
||||
f"SQL操作检查通过 - "
|
||||
f"SQL{operation_category}检查通过 - "
|
||||
f"操作: {risk_analysis['operation']}, "
|
||||
f"风险等级: {risk_analysis['risk_level'].name}, "
|
||||
f"影响表: {', '.join(risk_analysis['affected_tables'])}"
|
||||
|
||||
@ -38,6 +38,12 @@ class SQLOperationType:
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||
}
|
||||
|
||||
# 添加元数据操作集合
|
||||
self.metadata_operations = {
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
|
||||
# 风险等级配置
|
||||
self.allowed_risk_levels = self._parse_risk_levels()
|
||||
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||
@ -132,11 +138,17 @@ class SQLOperationType:
|
||||
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||
- UPDATE(无WHERE)=> HIGH
|
||||
- DELETE(无WHERE)=> CRITICAL
|
||||
4. 元数据操作:
|
||||
- SHOW/DESC/DESCRIBE等 => LOW
|
||||
"""
|
||||
# 危险操作
|
||||
if is_dangerous:
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 元数据操作
|
||||
if operation in self.metadata_operations:
|
||||
return SQLRiskLevel.LOW # 元数据查询视为低风险操作
|
||||
|
||||
# 生产环境中非SELECT操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
@ -8,6 +8,9 @@ load_dotenv()
|
||||
|
||||
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||
from src.tools.mysql_tool import register_mysql_tool
|
||||
from src.tools.mysql_metadata_tool import register_metadata_tools
|
||||
from src.tools.mysql_info_tool import register_info_tools
|
||||
from src.tools.mysql_schema_tool import register_schema_tools
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -19,6 +22,8 @@ logger = logging.getLogger("mysql_server")
|
||||
# 记录环境变量加载情况
|
||||
logger.debug("已加载环境变量")
|
||||
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||
logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}")
|
||||
logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}")
|
||||
|
||||
# 尝试导入MySQL连接器
|
||||
try:
|
||||
@ -40,9 +45,21 @@ logger.debug("正在创建MCP服务器实例...")
|
||||
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
|
||||
logger.debug("MCP服务器实例创建完成")
|
||||
|
||||
# 注册MySQL工具
|
||||
# 注册MySQL基础查询工具
|
||||
register_mysql_tool(mcp)
|
||||
|
||||
# 注册MySQL元数据查询工具
|
||||
register_metadata_tools(mcp)
|
||||
logger.debug("已注册元数据查询工具")
|
||||
|
||||
# 注册MySQL数据库信息查询工具
|
||||
register_info_tools(mcp)
|
||||
logger.debug("已注册数据库信息查询工具")
|
||||
|
||||
# 注册MySQL表结构高级查询工具
|
||||
register_schema_tools(mcp)
|
||||
logger.debug("已注册表结构高级查询工具")
|
||||
|
||||
def start_server():
|
||||
"""启动SSE服务器的同步包装器"""
|
||||
logger.debug("开始启动MySQL查询服务器...")
|
||||
|
||||
123
src/tools/metadata_base_tool.py
Normal file
123
src/tools/metadata_base_tool.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
MySQL元数据工具基类
|
||||
提供元数据查询工具的共享功能
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable
|
||||
import functools
|
||||
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
|
||||
T = TypeVar('T')
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
class MySQLToolError(Exception):
|
||||
"""MySQL工具异常基类"""
|
||||
pass
|
||||
|
||||
class ParameterValidationError(MySQLToolError):
|
||||
"""参数验证错误"""
|
||||
pass
|
||||
|
||||
class QueryExecutionError(MySQLToolError):
|
||||
"""查询执行错误"""
|
||||
pass
|
||||
|
||||
class MetadataToolBase:
|
||||
"""
|
||||
元数据查询工具基类
|
||||
提供共享功能:
|
||||
- 错误处理
|
||||
- 结果格式化
|
||||
- 参数验证
|
||||
- 异常处理
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def validate_parameter(param_name: str, param_value: Any, validator: Callable[[Any], bool],
|
||||
error_message: str) -> None:
|
||||
"""
|
||||
验证参数是否有效
|
||||
|
||||
Args:
|
||||
param_name: 参数名称
|
||||
param_value: 参数值
|
||||
validator: 验证函数
|
||||
error_message: 错误消息
|
||||
|
||||
Raises:
|
||||
ParameterValidationError: 当参数验证失败时
|
||||
"""
|
||||
if param_value is not None and not validator(param_value):
|
||||
raise ParameterValidationError(f"{param_name} - {error_message}")
|
||||
|
||||
@staticmethod
|
||||
def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str:
|
||||
"""
|
||||
格式化查询结果为JSON字符串
|
||||
|
||||
Args:
|
||||
results: 查询结果列表
|
||||
operation_type: 操作类型描述
|
||||
|
||||
Returns:
|
||||
格式化后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
# 添加元数据头部信息
|
||||
metadata_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": operation_type,
|
||||
"result_count": len(results)
|
||||
},
|
||||
"results": results
|
||||
}
|
||||
return json.dumps(metadata_info, default=str)
|
||||
except Exception as e:
|
||||
logger.error(f"结果格式化失败: {str(e)}")
|
||||
# 如果格式化失败,尝试直接序列化结果
|
||||
return json.dumps({"error": f"结果格式化失败: {str(e)}"})
|
||||
|
||||
@staticmethod
|
||||
def handle_query_error(func):
|
||||
"""
|
||||
装饰器: 统一处理查询执行过程中的错误
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except ParameterValidationError as e:
|
||||
logger.error(f"参数验证错误: {str(e)}")
|
||||
return json.dumps({"error": f"参数错误: {str(e)}"})
|
||||
except QueryExecutionError as e:
|
||||
logger.error(f"查询执行错误: {str(e)}")
|
||||
return json.dumps({"error": f"查询执行失败: {str(e)}"})
|
||||
except Exception as e:
|
||||
logger.error(f"未预期的错误: {str(e)}")
|
||||
return json.dumps({"error": f"操作失败: {str(e)}"})
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
async def execute_metadata_query(query: str, params: Optional[Dict[str, Any]] = None,
|
||||
operation_type: str = "元数据查询") -> str:
|
||||
"""
|
||||
执行元数据查询并返回格式化结果
|
||||
|
||||
Args:
|
||||
query: SQL查询语句
|
||||
params: 查询参数 (可选)
|
||||
operation_type: 操作类型描述
|
||||
|
||||
Returns:
|
||||
查询结果的JSON字符串
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
return MetadataToolBase.format_results(results, operation_type)
|
||||
except Exception as e:
|
||||
logger.error(f"元数据查询执行失败: {str(e)}")
|
||||
raise QueryExecutionError(str(e))
|
||||
332
src/tools/mysql_info_tool.py
Normal file
332
src/tools/mysql_info_tool.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""
|
||||
MySQL数据库信息查询工具
|
||||
提供数据库、变量和状态等系统信息查询功能
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.security.sql_analyzer import EnvironmentType
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 自定义异常类
|
||||
class SecurityError(QueryExecutionError):
|
||||
"""安全限制错误"""
|
||||
pass
|
||||
|
||||
# 从环境变量读取敏感字段列表
|
||||
def get_sensitive_patterns():
|
||||
"""从环境变量获取敏感字段模式列表"""
|
||||
default_patterns = [
|
||||
r'password', r'auth', r'credential', r'key', r'secret', r'private',
|
||||
r'host', r'path', r'directory', r'ssl', r'iptables', r'filter'
|
||||
]
|
||||
|
||||
env_patterns = os.getenv('SENSITIVE_INFO_FIELDS', '')
|
||||
if env_patterns:
|
||||
# 合并自定义模式和默认模式
|
||||
patterns = [pattern.strip() for pattern in env_patterns.split(',') if pattern.strip()]
|
||||
return list(set(patterns + default_patterns))
|
||||
|
||||
return default_patterns
|
||||
|
||||
# 敏感变量和状态关键字列表
|
||||
SENSITIVE_VARIABLE_PATTERNS = get_sensitive_patterns()
|
||||
|
||||
# 敏感变量名前缀,生产环境中这些变量的值会被隐藏
|
||||
SENSITIVE_VARIABLE_PREFIXES = [
|
||||
"password", "auth", "secret", "key", "certificate", "ssl", "tls", "cipher",
|
||||
"authentication", "secure", "credential", "token"
|
||||
]
|
||||
|
||||
def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool:
|
||||
"""
|
||||
检查当前环境是否允许执行特定类型的查询
|
||||
|
||||
Args:
|
||||
env_type: 环境类型(开发/生产)
|
||||
query_type: 查询类型
|
||||
|
||||
Returns:
|
||||
bool: 是否允许执行
|
||||
"""
|
||||
# 开发环境不限制查询
|
||||
if env_type == EnvironmentType.DEVELOPMENT:
|
||||
return True
|
||||
|
||||
# 生产环境限制敏感信息查询
|
||||
sensitive_queries = ['variables', 'status', 'processlist']
|
||||
|
||||
if query_type in sensitive_queries:
|
||||
# 检查是否在环境变量中明确允许
|
||||
allow_sensitive = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() == 'true'
|
||||
if not allow_sensitive:
|
||||
logger.warning(f"生产环境中禁止执行敏感查询: {query_type}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
过滤结果中的敏感信息
|
||||
|
||||
Args:
|
||||
results: 查询结果
|
||||
filter_patterns: 敏感信息的正则表达式模式列表
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表
|
||||
"""
|
||||
if not filter_patterns:
|
||||
filter_patterns = SENSITIVE_VARIABLE_PATTERNS
|
||||
|
||||
filtered_results = []
|
||||
for item in results:
|
||||
# 复制一份,避免修改原始数据
|
||||
filtered_item = item.copy()
|
||||
|
||||
# 检查常见的变量名字段
|
||||
for field in ['Variable_name', 'variable_name', 'name']:
|
||||
if field in filtered_item:
|
||||
var_name = filtered_item[field].lower()
|
||||
# 检查是否匹配敏感模式
|
||||
is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
|
||||
|
||||
if is_sensitive:
|
||||
# 敏感信息,隐藏具体的值
|
||||
for value_field in ['Value', 'value', 'variable_value']:
|
||||
if value_field in filtered_item:
|
||||
filtered_item[value_field] = '*** HIDDEN ***'
|
||||
|
||||
filtered_results.append(filtered_item)
|
||||
|
||||
return filtered_results
|
||||
|
||||
def register_info_tools(mcp: FastMCP):
|
||||
"""
|
||||
注册MySQL数据库信息查询工具到MCP服务器
|
||||
|
||||
Args:
|
||||
mcp: FastMCP服务器实例
|
||||
"""
|
||||
logger.debug("注册MySQL数据库信息查询工具...")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
|
||||
"""
|
||||
获取所有数据库列表,支持筛选和限制结果数量
|
||||
|
||||
Args:
|
||||
pattern: 数据库名称匹配模式 (可选, 例如 '%test%')
|
||||
limit: 返回结果的最大数量 (默认100,设为0表示无限制)
|
||||
exclude_system: 是否排除系统数据库 (默认为True)
|
||||
|
||||
Returns:
|
||||
数据库列表的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"limit", limit,
|
||||
lambda x: isinstance(x, int) and x >= 0,
|
||||
"返回结果的最大数量必须是非负整数"
|
||||
)
|
||||
|
||||
# 构建基础查询
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
# 执行查询
|
||||
with get_db_connection() as connection:
|
||||
# 先获取所有数据库
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
# 通常结果中每个数据库名会在"Database"字段
|
||||
db_field = next((k for k in results[0].keys() if k.lower() == 'database'), None) if results else None
|
||||
|
||||
if not db_field:
|
||||
logger.warning("查询结果未找到数据库名称字段")
|
||||
return MetadataToolBase.format_results(results, operation_type="数据库列表查询")
|
||||
|
||||
# 对结果进行过滤
|
||||
filtered_results = []
|
||||
system_dbs = ['information_schema', 'mysql', 'performance_schema', 'sys']
|
||||
|
||||
for item in results:
|
||||
db_name = item[db_field]
|
||||
|
||||
# 排除系统数据库
|
||||
if exclude_system and db_name.lower() in system_dbs:
|
||||
continue
|
||||
|
||||
# 根据模式过滤
|
||||
if pattern:
|
||||
pattern_regex = pattern.replace('%', '.*').replace('_', '.')
|
||||
if not re.search(pattern_regex, db_name, re.IGNORECASE):
|
||||
continue
|
||||
|
||||
filtered_results.append(item)
|
||||
|
||||
# 限制返回数量
|
||||
if limit > 0 and len(filtered_results) > limit:
|
||||
filtered_results = filtered_results[:limit]
|
||||
logger.debug(f"结果数量已限制为前{limit}个")
|
||||
|
||||
# 返回结果
|
||||
metadata_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": "数据库列表查询",
|
||||
"result_count": len(filtered_results),
|
||||
"total_count": len(results),
|
||||
"filtered": {
|
||||
"pattern": pattern,
|
||||
"exclude_system": exclude_system,
|
||||
"limited": len(filtered_results) < len(results)
|
||||
}
|
||||
},
|
||||
"results": filtered_results
|
||||
}
|
||||
|
||||
return json.dumps(metadata_info, default=str)
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_variables(pattern: Optional[str] = None, global_scope: bool = False) -> str:
|
||||
"""
|
||||
获取MySQL系统变量
|
||||
|
||||
Args:
|
||||
pattern: 变量名称匹配模式 (可选, 例如 '%buffer%')
|
||||
global_scope: 是否查询全局变量 (默认为会话变量)
|
||||
|
||||
Returns:
|
||||
系统变量的JSON字符串
|
||||
"""
|
||||
# 获取当前环境类型
|
||||
from src.security.sql_analyzer import sql_analyzer
|
||||
env_type = sql_analyzer.env_type
|
||||
|
||||
# 检查环境权限
|
||||
if not check_environment_permission(env_type, 'variables'):
|
||||
raise SecurityError("当前环境不允许查询系统变量")
|
||||
|
||||
# 参数验证
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
"变量模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
# 构建查询
|
||||
scope = "GLOBAL" if global_scope else "SESSION"
|
||||
query = f"SHOW {scope} VARIABLES"
|
||||
if pattern:
|
||||
query += f" LIKE '{pattern}'"
|
||||
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
# 生产环境中过滤敏感信息
|
||||
if env_type == EnvironmentType.PRODUCTION:
|
||||
results = filter_sensitive_info(results)
|
||||
|
||||
return MetadataToolBase.format_results(results, operation_type="系统变量查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_status(pattern: Optional[str] = None, global_scope: bool = False) -> str:
|
||||
"""
|
||||
获取MySQL服务器状态
|
||||
|
||||
Args:
|
||||
pattern: 状态名称匹配模式 (可选, 例如 '%conn%')
|
||||
global_scope: 是否查询全局状态 (默认为会话状态)
|
||||
|
||||
Returns:
|
||||
服务器状态的JSON字符串
|
||||
"""
|
||||
# 获取当前环境类型
|
||||
from src.security.sql_analyzer import sql_analyzer
|
||||
env_type = sql_analyzer.env_type
|
||||
|
||||
# 检查环境权限
|
||||
if not check_environment_permission(env_type, 'status'):
|
||||
raise SecurityError("当前环境不允许查询系统状态")
|
||||
|
||||
# 参数验证
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
"状态模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
# 构建查询
|
||||
scope = "GLOBAL" if global_scope else "SESSION"
|
||||
query = f"SHOW {scope} STATUS"
|
||||
if pattern:
|
||||
query += f" LIKE '{pattern}'"
|
||||
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query)
|
||||
|
||||
# 生产环境中过滤敏感信息
|
||||
if env_type == EnvironmentType.PRODUCTION:
|
||||
results = filter_sensitive_info(results)
|
||||
|
||||
return MetadataToolBase.format_results(results, operation_type="服务器状态查询")
|
||||
|
||||
# 工具函数: 用于参数验证
|
||||
def validate_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
验证模式字符串是否安全 (防止SQL注入)
|
||||
|
||||
Args:
|
||||
pattern: 要验证的模式字符串
|
||||
|
||||
Returns:
|
||||
如果模式安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当模式包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||
return True
|
||||
|
||||
def validate_engine_name(name: str) -> bool:
|
||||
"""
|
||||
验证存储引擎名称是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的引擎名称
|
||||
|
||||
Returns:
|
||||
如果引擎名称安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当引擎名称包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线")
|
||||
return True
|
||||
272
src/tools/mysql_metadata_tool.py
Normal file
272
src/tools/mysql_metadata_tool.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""
|
||||
MySQL元数据查询工具
|
||||
提供表结构等元数据信息查询功能
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 工具函数: 用于参数验证
|
||||
def validate_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
验证模式字符串是否安全 (防止SQL注入)
|
||||
|
||||
Args:
|
||||
pattern: 要验证的模式字符串
|
||||
|
||||
Returns:
|
||||
如果模式安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当模式包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字、下划线和通配符(% 和 _)
|
||||
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
|
||||
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
|
||||
return True
|
||||
|
||||
def validate_table_name(name: str) -> bool:
|
||||
"""
|
||||
验证表名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的表名
|
||||
|
||||
Returns:
|
||||
如果表名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当表名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_database_name(name: str) -> bool:
|
||||
"""
|
||||
验证数据库名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的数据库名
|
||||
|
||||
Returns:
|
||||
如果数据库名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据库名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def register_metadata_tools(mcp: FastMCP):
|
||||
"""
|
||||
注册MySQL元数据查询工具到MCP服务器
|
||||
|
||||
Args:
|
||||
mcp: FastMCP服务器实例
|
||||
"""
|
||||
logger.debug("注册MySQL元数据查询工具...")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_tables(database: Optional[str] = None, pattern: Optional[str] = None,
|
||||
limit: int = 100, exclude_views: bool = False) -> str:
|
||||
"""
|
||||
获取数据库中的表列表,支持筛选和限制结果数量
|
||||
|
||||
Args:
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
pattern: 表名匹配模式 (可选, 例如 '%user%')
|
||||
limit: 返回结果的最大数量 (默认100,设为0表示无限制)
|
||||
exclude_views: 是否排除视图 (默认为False)
|
||||
|
||||
Returns:
|
||||
表列表的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if pattern:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"pattern", pattern,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
|
||||
"模式只能包含字母、数字、下划线和通配符(%_)"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"limit", limit,
|
||||
lambda x: isinstance(x, int) and x >= 0,
|
||||
"返回结果的最大数量必须是非负整数"
|
||||
)
|
||||
|
||||
# 基础查询
|
||||
base_query = "SHOW FULL TABLES" if exclude_views else "SHOW TABLES"
|
||||
if database:
|
||||
base_query += f" FROM `{database}`"
|
||||
if pattern:
|
||||
base_query += f" LIKE '{pattern}'"
|
||||
|
||||
logger.debug(f"执行查询: {base_query}")
|
||||
|
||||
# 执行查询
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, base_query)
|
||||
|
||||
# 如果需要排除视图,且使用的是SHOW FULL TABLES
|
||||
if exclude_views and "FULL" in base_query:
|
||||
filtered_results = []
|
||||
|
||||
# 查找表名和表类型字段
|
||||
fields = list(results[0].keys()) if results else []
|
||||
table_field = fields[0] if fields else None
|
||||
table_type_field = fields[1] if len(fields) > 1 else None
|
||||
|
||||
if table_field and table_type_field:
|
||||
# 基表类型通常是"BASE TABLE"
|
||||
for item in results:
|
||||
if item[table_type_field] == 'BASE TABLE':
|
||||
filtered_results.append(item)
|
||||
else:
|
||||
filtered_results = results
|
||||
else:
|
||||
filtered_results = results
|
||||
|
||||
# 限制返回数量
|
||||
if limit > 0 and len(filtered_results) > limit:
|
||||
limited_results = filtered_results[:limit]
|
||||
is_limited = True
|
||||
else:
|
||||
limited_results = filtered_results
|
||||
is_limited = False
|
||||
|
||||
# 构造元数据
|
||||
metadata_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": "表列表查询",
|
||||
"result_count": len(limited_results),
|
||||
"total_count": len(results),
|
||||
"filtered": {
|
||||
"database": database,
|
||||
"pattern": pattern,
|
||||
"exclude_views": exclude_views and "FULL" in base_query,
|
||||
"limited": is_limited
|
||||
}
|
||||
},
|
||||
"results": limited_results
|
||||
}
|
||||
|
||||
return json.dumps(metadata_info, default=str)
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_columns(table: str, database: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取表的列信息
|
||||
|
||||
Args:
|
||||
table: 表名
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
|
||||
Returns:
|
||||
表列信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
query = f"SHOW COLUMNS FROM `{table}`" if not database else f"SHOW COLUMNS FROM `{database}`.`{table}`"
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
return await MetadataToolBase.execute_metadata_query(query, operation_type="表列信息查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_describe_table(table: str, database: Optional[str] = None) -> str:
|
||||
"""
|
||||
描述表结构
|
||||
|
||||
Args:
|
||||
table: 表名
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
|
||||
Returns:
|
||||
表结构描述的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
# DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
|
||||
query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_create_table(table: str, database: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取表的创建语句
|
||||
|
||||
Args:
|
||||
table: 表名
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
|
||||
Returns:
|
||||
表创建语句的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"table", table,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"表名只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
if database:
|
||||
MetadataToolBase.validate_parameter(
|
||||
"database", database,
|
||||
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
|
||||
"数据库名称只能包含字母、数字和下划线"
|
||||
)
|
||||
|
||||
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
|
||||
query = f"SHOW CREATE TABLE {table_ref}"
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
return await MetadataToolBase.execute_metadata_query(query, operation_type="表创建语句查询")
|
||||
294
src/tools/mysql_schema_tool.py
Normal file
294
src/tools/mysql_schema_tool.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""
|
||||
MySQL表结构高级查询工具
|
||||
提供索引、约束、表状态等高级元数据查询功能
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||
from src.db.mysql_operations import get_db_connection, execute_query
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 参数验证函数
|
||||
def validate_table_name(name: str) -> bool:
|
||||
"""
|
||||
验证表名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的表名
|
||||
|
||||
Returns:
|
||||
如果表名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当表名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_database_name(name: str) -> bool:
|
||||
"""
|
||||
验证数据库名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的数据库名
|
||||
|
||||
Returns:
|
||||
如果数据库名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据库名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
def validate_column_name(name: str) -> bool:
|
||||
"""
|
||||
验证列名是否合法安全
|
||||
|
||||
Args:
|
||||
name: 要验证的列名
|
||||
|
||||
Returns:
|
||||
如果列名安全返回True,否则抛出ValueError
|
||||
|
||||
Raises:
|
||||
ValueError: 当列名包含不安全字符时
|
||||
"""
|
||||
# 仅允许字母、数字和下划线
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', name):
|
||||
raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线")
|
||||
return True
|
||||
|
||||
async def execute_schema_query(
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
operation_type: str = "元数据查询"
|
||||
) -> str:
|
||||
"""
|
||||
执行表结构查询
|
||||
|
||||
Args:
|
||||
query: SQL查询语句
|
||||
params: 查询参数 (可选)
|
||||
operation_type: 操作类型描述
|
||||
|
||||
Returns:
|
||||
查询结果的JSON字符串
|
||||
"""
|
||||
with get_db_connection() as connection:
|
||||
results = await execute_query(connection, query, params)
|
||||
return MetadataToolBase.format_results(results, operation_type)
|
||||
|
||||
def register_schema_tools(mcp: FastMCP):
|
||||
"""
|
||||
注册MySQL表结构高级查询工具到MCP服务器
|
||||
|
||||
Args:
|
||||
mcp: FastMCP服务器实例
|
||||
"""
|
||||
logger.debug("注册MySQL表结构高级查询工具...")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_indexes(table: str, database: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取表的索引信息
|
||||
|
||||
Args:
|
||||
table: 表名
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
|
||||
Returns:
|
||||
表索引信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
validate_table_name(table)
|
||||
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
|
||||
# 构建查询
|
||||
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
|
||||
query = f"SHOW INDEX FROM {table_ref}"
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
# 执行查询
|
||||
return await execute_schema_query(query, operation_type="表索引查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_table_status(database: Optional[str] = None, like_pattern: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取表状态信息
|
||||
|
||||
Args:
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
like_pattern: 表名匹配模式 (可选,例如 '%user%')
|
||||
|
||||
Returns:
|
||||
表状态信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
|
||||
if like_pattern:
|
||||
validate_column_name(like_pattern)
|
||||
|
||||
# 构建查询
|
||||
if database:
|
||||
query = f"SHOW TABLE STATUS FROM `{database}`"
|
||||
else:
|
||||
query = "SHOW TABLE STATUS"
|
||||
|
||||
if like_pattern:
|
||||
query += f" LIKE '{like_pattern}'"
|
||||
|
||||
logger.debug(f"执行查询: {query}")
|
||||
|
||||
# 执行查询
|
||||
return await execute_schema_query(query, operation_type="表状态查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_show_foreign_keys(table: str, database: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取表的外键约束信息
|
||||
|
||||
Args:
|
||||
table: 表名
|
||||
database: 数据库名称 (可选,默认使用当前连接的数据库)
|
||||
|
||||
Returns:
|
||||
表外键约束信息的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
validate_table_name(table)
|
||||
|
||||
if database:
|
||||
validate_database_name(database)
|
||||
|
||||
# 确定数据库名
|
||||
db_name = database
|
||||
if not db_name:
|
||||
# 获取当前数据库
|
||||
with get_db_connection() as connection:
|
||||
current_db_results = await execute_query(connection, "SELECT DATABASE() as db")
|
||||
if current_db_results and 'db' in current_db_results[0]:
|
||||
db_name = current_db_results[0]['db']
|
||||
|
||||
if not db_name:
|
||||
raise ValueError("无法确定数据库名称,请明确指定database参数")
|
||||
|
||||
# 使用INFORMATION_SCHEMA查询外键
|
||||
query = """
|
||||
SELECT
|
||||
CONSTRAINT_NAME,
|
||||
TABLE_NAME,
|
||||
COLUMN_NAME,
|
||||
REFERENCED_TABLE_NAME,
|
||||
REFERENCED_COLUMN_NAME,
|
||||
UPDATE_RULE,
|
||||
DELETE_RULE
|
||||
FROM
|
||||
INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
|
||||
JOIN
|
||||
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
|
||||
ON
|
||||
kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
|
||||
WHERE
|
||||
kcu.TABLE_SCHEMA = %s
|
||||
AND kcu.TABLE_NAME = %s
|
||||
AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
|
||||
"""
|
||||
params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table}
|
||||
|
||||
logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束")
|
||||
|
||||
# 执行查询
|
||||
return await execute_schema_query(query, params, operation_type="外键约束查询")
|
||||
|
||||
@mcp.tool()
|
||||
@MetadataToolBase.handle_query_error
|
||||
async def mysql_paginate_results(query: str, page: int = 1, page_size: int = 50) -> str:
|
||||
"""
|
||||
分页执行查询以处理大型结果集
|
||||
|
||||
Args:
|
||||
query: SQL查询语句
|
||||
page: 页码 (从1开始)
|
||||
page_size: 每页记录数 (默认50)
|
||||
|
||||
Returns:
|
||||
分页结果的JSON字符串
|
||||
"""
|
||||
# 参数验证
|
||||
MetadataToolBase.validate_parameter(
|
||||
"page", page,
|
||||
lambda x: isinstance(x, int) and x > 0,
|
||||
"页码必须是正整数"
|
||||
)
|
||||
|
||||
MetadataToolBase.validate_parameter(
|
||||
"page_size", page_size,
|
||||
lambda x: isinstance(x, int) and 1 <= x <= 1000,
|
||||
"每页记录数必须在1-1000之间"
|
||||
)
|
||||
|
||||
# 检查查询语法
|
||||
if not query.strip().upper().startswith('SELECT'):
|
||||
raise ValueError("只支持SELECT查询的分页")
|
||||
|
||||
# 计算LIMIT和OFFSET
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 在查询末尾添加LIMIT子句
|
||||
paginated_query = query.strip()
|
||||
if 'LIMIT' in paginated_query.upper():
|
||||
raise ValueError("查询已包含LIMIT子句,请移除后重试")
|
||||
|
||||
paginated_query += f" LIMIT {page_size} OFFSET {offset}"
|
||||
|
||||
logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}")
|
||||
|
||||
# 获取总记录数(用于计算总页数)
|
||||
count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table"
|
||||
|
||||
with get_db_connection() as connection:
|
||||
# 执行分页查询
|
||||
results = await execute_query(connection, paginated_query)
|
||||
|
||||
# 获取总记录数
|
||||
count_results = await execute_query(connection, count_query)
|
||||
total_records = count_results[0]['total'] if count_results else 0
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (total_records + page_size - 1) // page_size
|
||||
|
||||
# 构建分页元数据
|
||||
pagination_info = {
|
||||
"metadata_info": {
|
||||
"operation_type": "分页查询",
|
||||
"result_count": len(results),
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_records": total_records,
|
||||
"total_pages": total_pages
|
||||
}
|
||||
},
|
||||
"results": results
|
||||
}
|
||||
|
||||
return json.dumps(pagination_info, default=str)
|
||||
Reference in New Issue
Block a user