<feat> 增加元数据操作功能

This commit is contained in:
tangyi
2025-04-08 11:40:25 +08:00
parent b0463903f5
commit e0550a531d
9 changed files with 1167 additions and 33 deletions

101
README.md
View File

@ -16,11 +16,45 @@
- 危险操作拦截
- WHERE子句强制检查
- 自动返回修改操作影响的行数
- 敏感信息保护机制
- 自动对元数据查询结果进行格式化和增强
## API接口功能
系统提供以下四大类工具:
### 基础查询工具
- `mysql_query`: 执行任意SQL查询支持参数化查询
### 元数据查询工具
- `mysql_show_tables`: 获取数据库中的表列表,支持模式匹配和限制结果数量
- `mysql_show_columns`: 获取表的列信息
- `mysql_describe_table`: 描述表结构
- `mysql_show_create_table`: 获取表的创建语句
### 数据库信息查询工具
- `mysql_show_databases`: 获取所有数据库列表,支持过滤系统数据库
- `mysql_show_variables`: 获取MySQL服务器变量
- `mysql_show_status`: 获取MySQL服务器状态信息
### 表结构高级查询工具
- `mysql_show_indexes`: 获取表的索引信息
- `mysql_show_table_status`: 获取表状态信息
- `mysql_show_foreign_keys`: 获取表的外键约束信息
- `mysql_paginate_results`: 提供结果分页功能
## 系统要求
- Python 3.6+
- MySQL服务器
- 依赖包:
- mysql-connector-python
- python-dotenv
- mcp (FastMCP框架)
## 安装步骤
@ -57,34 +91,33 @@ pip install -r requirements.txt
- `ALLOWED_RISK_LEVELS`: 允许的风险等级LOW/MEDIUM/HIGH/CRITICAL
- `BLOCKED_PATTERNS`: 禁止的SQL模式正则表达式用逗号分隔
- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查true/false
- `ALLOW_SENSITIVE_INFO`: 是否允许查询敏感信息true/false
- `SENSITIVE_INFO_FIELDS`: 自定义敏感字段模式列表(逗号分隔)
## SQL安全机制
## 安全机制详解
### 风险等级控制
- LOW: 查询操作SELECT
- LOW: 查询操作SELECT和元数据操作SHOW, DESCRIBE等
- MEDIUM: 基本数据修改INSERT有WHERE的UPDATE/DELETE
- HIGH: 结构变更CREATE/ALTER和无WHERE的UPDATE
- CRITICAL: 危险操作DROP/TRUNCATE和无WHERE的DELETE操作
### 环境变量加载顺序
项目使用python-dotenv加载环境变量需要确保在导入其他模块前先加载环境变量否则可能会导致配置未被正确应用。
### 环境特性差异
- 开发环境:
- 允许较高风险的操作
- 不隐藏敏感信息
- 提供详细的错误信息
- 生产环境:
- 默认只允许LOW风险操作
- 严格限制数据修改
- 自动隐藏敏感信息
- 错误信息不暴露实现细节
### 安全检查机制
- 强制要求UPDATE/DELETE操作包含WHERE子句
- SQL语句语法检查
- 危险操作模式检测
- 自动识别受影响的表
- 风险等级评估
### 环境特性
- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则
- 生产环境默认只允许LOW风险操作严格限制数据修改
### 安全拦截
- SQL语句长度限制
- 危险操作模式检测
- 自动识别受影响的表
- SQL注入防护
### 敏感信息保护
系统会自动检测并隐藏包含以下关键词的变量/状态值:
- password、auth、credential、key、secret、private
- ssl、tls、cipher、certificate
- host、path、directory等系统路径信息
### 事务管理
- 对于修改操作INSERT/UPDATE/DELETE会自动提交事务
@ -105,17 +138,23 @@ python src/server.py
```
.
├── src/ # 源代码目录
│ ├── server.py # 主服务器文件
│ ├── db/ # 数据库相关代码
├── security/ # SQL安全相关代码
│ ├── interceptor.py # SQL拦截器
├── src/ # 源代码目录
│ ├── server.py # 主服务器文件
│ ├── db/ # 数据库相关代码
│ └── mysql_operations.py # MySQL操作实现
│ ├── security/ # SQL安全相关代码
│ │ ├── interceptor.py # SQL拦截器
│ │ ├── query_limiter.py # SQL安全检查器
│ │ └── sql_analyzer.py # SQL分析器
│ └── tools/ # 工具类代码
├── tests/ # 测试代码目录
├── .env.example # 环境变量示例文件
└── requirements.txt # 项目依赖文件
│ └── tools/ # 工具类代码
│ ├── mysql_tool.py # 基础查询工具
│ ├── mysql_metadata_tool.py # 元数据查询工具
│ ├── mysql_info_tool.py # 数据库信息查询工具
│ ├── mysql_schema_tool.py # 表结构高级查询工具
│ └── metadata_base_tool.py # 元数据工具基类
├── tests/ # 测试代码目录
├── .env.example # 环境变量示例文件
└── requirements.txt # 项目依赖文件
```
## 常见问题解决
@ -136,6 +175,10 @@ python src/server.py
- 如果需要执行高风险操作相应地调整ALLOWED_RISK_LEVELS
- 对于不带WHERE条件的UPDATE或DELETE可以添加条件即使是WHERE 1=1降低风险级别
### 无法查看敏感信息
- 在开发环境中设置ALLOW_SENSITIVE_INFO=true
- 在生产环境中,敏感信息默认会被隐藏,这是安全特性
## 日志系统
服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。

View File

@ -81,6 +81,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
ValueError: 当查询执行失败时
"""
cursor = None
operation = None # 初始化操作类型变量
try:
# 安全检查
if not await sql_interceptor.check_operation(query):
@ -105,6 +106,33 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
return [{'affected_rows': affected_rows}]
# 处理元数据查询操作
if operation in sql_analyzer.metadata_operations:
# 获取结果集
results = cursor.fetchall()
# 没有结果时返回空列表但添加元信息
if not results:
logger.debug(f"元数据查询 {operation} 没有返回结果")
return [{'metadata_operation': operation, 'result_count': 0}]
# 优化结果格式 - 为元数据结果添加额外信息
metadata_results = []
for row in results:
# 对某些特定元数据查询进行特殊处理
if operation == 'SHOW' and 'Table' in row:
# SHOW TABLES 结果增强
row['table_name'] = row['Table']
elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row:
# DESC/DESCRIBE 表结构结果增强
row['column_name'] = row['Field']
row['data_type'] = row['Type']
metadata_results.append(row)
logger.debug(f"元数据查询 {operation} 返回 {len(metadata_results)} 条结果")
return metadata_results
# 对于查询操作,返回结果集
results = cursor.fetchall()
logger.debug(f"查询返回 {len(results)} 条结果")
@ -115,7 +143,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
raise
except mysql.connector.Error as query_err:
# 如果发生错误,进行回滚
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}: # 确保operation已定义
try:
connection.rollback()
logger.debug("事务已回滚")

View File

@ -46,7 +46,15 @@ class SQLInterceptor:
raise SecurityException("SQL语句格式无效")
operation = sql_parts[0].upper()
if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}:
# 更新支持的操作类型列表,包括元数据操作
supported_operations = {
'SELECT', 'INSERT', 'UPDATE', 'DELETE',
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE',
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
}
if operation not in supported_operations:
raise SecurityException(f"不支持的SQL操作: {operation}")
# 分析SQL风险
@ -65,9 +73,14 @@ class SQLInterceptor:
f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}"
)
# 确定操作类型DDL, DML 或 元数据)
operation_category = "元数据操作" if operation in self.analyzer.metadata_operations else (
"DDL操作" if operation in self.analyzer.ddl_operations else "DML操作"
)
# 记录详细日志
logger.info(
f"SQL操作检查通过 - "
f"SQL{operation_category}检查通过 - "
f"操作: {risk_analysis['operation']}, "
f"风险等级: {risk_analysis['risk_level'].name}, "
f"影响表: {', '.join(risk_analysis['affected_tables'])}"

View File

@ -38,6 +38,12 @@ class SQLOperationType:
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
}
# 添加元数据操作集合
self.metadata_operations = {
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
}
# 风险等级配置
self.allowed_risk_levels = self._parse_risk_levels()
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
@ -132,11 +138,17 @@ class SQLOperationType:
- UPDATE/DELETE有WHERE=> MEDIUM
- UPDATE无WHERE=> HIGH
- DELETE无WHERE=> CRITICAL
4. 元数据操作:
- SHOW/DESC/DESCRIBE等 => LOW
"""
# 危险操作
if is_dangerous:
return SQLRiskLevel.CRITICAL
# 元数据操作
if operation in self.metadata_operations:
return SQLRiskLevel.LOW # 元数据查询视为低风险操作
# 生产环境中非SELECT操作
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
return SQLRiskLevel.CRITICAL

View File

@ -8,6 +8,9 @@ load_dotenv()
# 导入自定义模块 - 确保在load_dotenv之后导入
from src.tools.mysql_tool import register_mysql_tool
from src.tools.mysql_metadata_tool import register_metadata_tools
from src.tools.mysql_info_tool import register_info_tools
from src.tools.mysql_schema_tool import register_schema_tools
# 配置日志
logging.basicConfig(
@ -19,6 +22,8 @@ logger = logging.getLogger("mysql_server")
# 记录环境变量加载情况
logger.debug("已加载环境变量")
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}")
logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}")
# 尝试导入MySQL连接器
try:
@ -40,9 +45,21 @@ logger.debug("正在创建MCP服务器实例...")
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
logger.debug("MCP服务器实例创建完成")
# 注册MySQL工具
# 注册MySQL基础查询工具
register_mysql_tool(mcp)
# 注册MySQL元数据查询工具
register_metadata_tools(mcp)
logger.debug("已注册元数据查询工具")
# 注册MySQL数据库信息查询工具
register_info_tools(mcp)
logger.debug("已注册数据库信息查询工具")
# 注册MySQL表结构高级查询工具
register_schema_tools(mcp)
logger.debug("已注册表结构高级查询工具")
def start_server():
"""启动SSE服务器的同步包装器"""
logger.debug("开始启动MySQL查询服务器...")

View File

@ -0,0 +1,123 @@
"""
MySQL元数据工具基类
提供元数据查询工具的共享功能
"""
import json
import logging
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable
import functools
from src.db.mysql_operations import get_db_connection, execute_query
T = TypeVar('T')
logger = logging.getLogger("mysql_server")
class MySQLToolError(Exception):
"""MySQL工具异常基类"""
pass
class ParameterValidationError(MySQLToolError):
"""参数验证错误"""
pass
class QueryExecutionError(MySQLToolError):
"""查询执行错误"""
pass
class MetadataToolBase:
"""
元数据查询工具基类
提供共享功能:
- 错误处理
- 结果格式化
- 参数验证
- 异常处理
"""
@staticmethod
def validate_parameter(param_name: str, param_value: Any, validator: Callable[[Any], bool],
error_message: str) -> None:
"""
验证参数是否有效
Args:
param_name: 参数名称
param_value: 参数值
validator: 验证函数
error_message: 错误消息
Raises:
ParameterValidationError: 当参数验证失败时
"""
if param_value is not None and not validator(param_value):
raise ParameterValidationError(f"{param_name} - {error_message}")
@staticmethod
def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str:
"""
格式化查询结果为JSON字符串
Args:
results: 查询结果列表
operation_type: 操作类型描述
Returns:
格式化后的JSON字符串
"""
try:
# 添加元数据头部信息
metadata_info = {
"metadata_info": {
"operation_type": operation_type,
"result_count": len(results)
},
"results": results
}
return json.dumps(metadata_info, default=str)
except Exception as e:
logger.error(f"结果格式化失败: {str(e)}")
# 如果格式化失败,尝试直接序列化结果
return json.dumps({"error": f"结果格式化失败: {str(e)}"})
@staticmethod
def handle_query_error(func):
"""
装饰器: 统一处理查询执行过程中的错误
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except ParameterValidationError as e:
logger.error(f"参数验证错误: {str(e)}")
return json.dumps({"error": f"参数错误: {str(e)}"})
except QueryExecutionError as e:
logger.error(f"查询执行错误: {str(e)}")
return json.dumps({"error": f"查询执行失败: {str(e)}"})
except Exception as e:
logger.error(f"未预期的错误: {str(e)}")
return json.dumps({"error": f"操作失败: {str(e)}"})
return wrapper
@staticmethod
async def execute_metadata_query(query: str, params: Optional[Dict[str, Any]] = None,
operation_type: str = "元数据查询") -> str:
"""
执行元数据查询并返回格式化结果
Args:
query: SQL查询语句
params: 查询参数 (可选)
operation_type: 操作类型描述
Returns:
查询结果的JSON字符串
"""
try:
with get_db_connection() as connection:
results = await execute_query(connection, query, params)
return MetadataToolBase.format_results(results, operation_type)
except Exception as e:
logger.error(f"元数据查询执行失败: {str(e)}")
raise QueryExecutionError(str(e))

View File

@ -0,0 +1,332 @@
"""
MySQL数据库信息查询工具
提供数据库、变量和状态等系统信息查询功能
"""
import json
import logging
import re
import os
from typing import Any, Dict, List, Optional
from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.security.sql_analyzer import EnvironmentType
from src.db.mysql_operations import get_db_connection, execute_query
logger = logging.getLogger("mysql_server")
# 自定义异常类
class SecurityError(QueryExecutionError):
"""安全限制错误"""
pass
# 从环境变量读取敏感字段列表
def get_sensitive_patterns():
"""从环境变量获取敏感字段模式列表"""
default_patterns = [
r'password', r'auth', r'credential', r'key', r'secret', r'private',
r'host', r'path', r'directory', r'ssl', r'iptables', r'filter'
]
env_patterns = os.getenv('SENSITIVE_INFO_FIELDS', '')
if env_patterns:
# 合并自定义模式和默认模式
patterns = [pattern.strip() for pattern in env_patterns.split(',') if pattern.strip()]
return list(set(patterns + default_patterns))
return default_patterns
# 敏感变量和状态关键字列表
SENSITIVE_VARIABLE_PATTERNS = get_sensitive_patterns()
# 敏感变量名前缀,生产环境中这些变量的值会被隐藏
SENSITIVE_VARIABLE_PREFIXES = [
"password", "auth", "secret", "key", "certificate", "ssl", "tls", "cipher",
"authentication", "secure", "credential", "token"
]
def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool:
"""
检查当前环境是否允许执行特定类型的查询
Args:
env_type: 环境类型(开发/生产)
query_type: 查询类型
Returns:
bool: 是否允许执行
"""
# 开发环境不限制查询
if env_type == EnvironmentType.DEVELOPMENT:
return True
# 生产环境限制敏感信息查询
sensitive_queries = ['variables', 'status', 'processlist']
if query_type in sensitive_queries:
# 检查是否在环境变量中明确允许
allow_sensitive = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() == 'true'
if not allow_sensitive:
logger.warning(f"生产环境中禁止执行敏感查询: {query_type}")
return False
return True
def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[str] = None) -> List[Dict[str, Any]]:
"""
过滤结果中的敏感信息
Args:
results: 查询结果
filter_patterns: 敏感信息的正则表达式模式列表
Returns:
过滤后的结果列表
"""
if not filter_patterns:
filter_patterns = SENSITIVE_VARIABLE_PATTERNS
filtered_results = []
for item in results:
# 复制一份,避免修改原始数据
filtered_item = item.copy()
# 检查常见的变量名字段
for field in ['Variable_name', 'variable_name', 'name']:
if field in filtered_item:
var_name = filtered_item[field].lower()
# 检查是否匹配敏感模式
is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
if is_sensitive:
# 敏感信息,隐藏具体的值
for value_field in ['Value', 'value', 'variable_value']:
if value_field in filtered_item:
filtered_item[value_field] = '*** HIDDEN ***'
filtered_results.append(filtered_item)
return filtered_results
def register_info_tools(mcp: FastMCP):
"""
注册MySQL数据库信息查询工具到MCP服务器
Args:
mcp: FastMCP服务器实例
"""
logger.debug("注册MySQL数据库信息查询工具...")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
"""
获取所有数据库列表,支持筛选和限制结果数量
Args:
pattern: 数据库名称匹配模式 (可选, 例如 '%test%')
limit: 返回结果的最大数量 (默认100设为0表示无限制)
exclude_system: 是否排除系统数据库 (默认为True)
Returns:
数据库列表的JSON字符串
"""
# 参数验证
if pattern:
MetadataToolBase.validate_parameter(
"pattern", pattern,
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
"模式只能包含字母、数字、下划线和通配符(%_)"
)
MetadataToolBase.validate_parameter(
"limit", limit,
lambda x: isinstance(x, int) and x >= 0,
"返回结果的最大数量必须是非负整数"
)
# 构建基础查询
query = "SHOW DATABASES"
# 执行查询
with get_db_connection() as connection:
# 先获取所有数据库
results = await execute_query(connection, query)
# 通常结果中每个数据库名会在"Database"字段
db_field = next((k for k in results[0].keys() if k.lower() == 'database'), None) if results else None
if not db_field:
logger.warning("查询结果未找到数据库名称字段")
return MetadataToolBase.format_results(results, operation_type="数据库列表查询")
# 对结果进行过滤
filtered_results = []
system_dbs = ['information_schema', 'mysql', 'performance_schema', 'sys']
for item in results:
db_name = item[db_field]
# 排除系统数据库
if exclude_system and db_name.lower() in system_dbs:
continue
# 根据模式过滤
if pattern:
pattern_regex = pattern.replace('%', '.*').replace('_', '.')
if not re.search(pattern_regex, db_name, re.IGNORECASE):
continue
filtered_results.append(item)
# 限制返回数量
if limit > 0 and len(filtered_results) > limit:
filtered_results = filtered_results[:limit]
logger.debug(f"结果数量已限制为前{limit}")
# 返回结果
metadata_info = {
"metadata_info": {
"operation_type": "数据库列表查询",
"result_count": len(filtered_results),
"total_count": len(results),
"filtered": {
"pattern": pattern,
"exclude_system": exclude_system,
"limited": len(filtered_results) < len(results)
}
},
"results": filtered_results
}
return json.dumps(metadata_info, default=str)
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_variables(pattern: Optional[str] = None, global_scope: bool = False) -> str:
"""
获取MySQL系统变量
Args:
pattern: 变量名称匹配模式 (可选, 例如 '%buffer%')
global_scope: 是否查询全局变量 (默认为会话变量)
Returns:
系统变量的JSON字符串
"""
# 获取当前环境类型
from src.security.sql_analyzer import sql_analyzer
env_type = sql_analyzer.env_type
# 检查环境权限
if not check_environment_permission(env_type, 'variables'):
raise SecurityError("当前环境不允许查询系统变量")
# 参数验证
if pattern:
MetadataToolBase.validate_parameter(
"pattern", pattern,
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
"变量模式只能包含字母、数字、下划线和通配符(%_)"
)
# 构建查询
scope = "GLOBAL" if global_scope else "SESSION"
query = f"SHOW {scope} VARIABLES"
if pattern:
query += f" LIKE '{pattern}'"
logger.debug(f"执行查询: {query}")
with get_db_connection() as connection:
results = await execute_query(connection, query)
# 生产环境中过滤敏感信息
if env_type == EnvironmentType.PRODUCTION:
results = filter_sensitive_info(results)
return MetadataToolBase.format_results(results, operation_type="系统变量查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_status(pattern: Optional[str] = None, global_scope: bool = False) -> str:
"""
获取MySQL服务器状态
Args:
pattern: 状态名称匹配模式 (可选, 例如 '%conn%')
global_scope: 是否查询全局状态 (默认为会话状态)
Returns:
服务器状态的JSON字符串
"""
# 获取当前环境类型
from src.security.sql_analyzer import sql_analyzer
env_type = sql_analyzer.env_type
# 检查环境权限
if not check_environment_permission(env_type, 'status'):
raise SecurityError("当前环境不允许查询系统状态")
# 参数验证
if pattern:
MetadataToolBase.validate_parameter(
"pattern", pattern,
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
"状态模式只能包含字母、数字、下划线和通配符(%_)"
)
# 构建查询
scope = "GLOBAL" if global_scope else "SESSION"
query = f"SHOW {scope} STATUS"
if pattern:
query += f" LIKE '{pattern}'"
logger.debug(f"执行查询: {query}")
with get_db_connection() as connection:
results = await execute_query(connection, query)
# 生产环境中过滤敏感信息
if env_type == EnvironmentType.PRODUCTION:
results = filter_sensitive_info(results)
return MetadataToolBase.format_results(results, operation_type="服务器状态查询")
# 工具函数: 用于参数验证
def validate_pattern(pattern: str) -> bool:
"""
验证模式字符串是否安全 (防止SQL注入)
Args:
pattern: 要验证的模式字符串
Returns:
如果模式安全返回True否则抛出ValueError
Raises:
ValueError: 当模式包含不安全字符时
"""
# 仅允许字母、数字、下划线和通配符(% 和 _)
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
return True
def validate_engine_name(name: str) -> bool:
"""
验证存储引擎名称是否合法安全
Args:
name: 要验证的引擎名称
Returns:
如果引擎名称安全返回True否则抛出ValueError
Raises:
ValueError: 当引擎名称包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线")
return True

View File

@ -0,0 +1,272 @@
"""
MySQL元数据查询工具
提供表结构等元数据信息查询功能
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional
from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.db.mysql_operations import get_db_connection, execute_query
logger = logging.getLogger("mysql_server")
# 工具函数: 用于参数验证
def validate_pattern(pattern: str) -> bool:
"""
验证模式字符串是否安全 (防止SQL注入)
Args:
pattern: 要验证的模式字符串
Returns:
如果模式安全返回True否则抛出ValueError
Raises:
ValueError: 当模式包含不安全字符时
"""
# 仅允许字母、数字、下划线和通配符(% 和 _)
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
return True
def validate_table_name(name: str) -> bool:
"""
验证表名是否合法安全
Args:
name: 要验证的表名
Returns:
如果表名安全返回True否则抛出ValueError
Raises:
ValueError: 当表名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
return True
def validate_database_name(name: str) -> bool:
"""
验证数据库名是否合法安全
Args:
name: 要验证的数据库名
Returns:
如果数据库名安全返回True否则抛出ValueError
Raises:
ValueError: 当数据库名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
return True
def register_metadata_tools(mcp: FastMCP):
"""
注册MySQL元数据查询工具到MCP服务器
Args:
mcp: FastMCP服务器实例
"""
logger.debug("注册MySQL元数据查询工具...")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_tables(database: Optional[str] = None, pattern: Optional[str] = None,
limit: int = 100, exclude_views: bool = False) -> str:
"""
获取数据库中的表列表,支持筛选和限制结果数量
Args:
database: 数据库名称 (可选,默认使用当前连接的数据库)
pattern: 表名匹配模式 (可选, 例如 '%user%')
limit: 返回结果的最大数量 (默认100设为0表示无限制)
exclude_views: 是否排除视图 (默认为False)
Returns:
表列表的JSON字符串
"""
# 参数验证
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"数据库名称只能包含字母、数字和下划线"
)
if pattern:
MetadataToolBase.validate_parameter(
"pattern", pattern,
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
"模式只能包含字母、数字、下划线和通配符(%_)"
)
MetadataToolBase.validate_parameter(
"limit", limit,
lambda x: isinstance(x, int) and x >= 0,
"返回结果的最大数量必须是非负整数"
)
# 基础查询
base_query = "SHOW FULL TABLES" if exclude_views else "SHOW TABLES"
if database:
base_query += f" FROM `{database}`"
if pattern:
base_query += f" LIKE '{pattern}'"
logger.debug(f"执行查询: {base_query}")
# 执行查询
with get_db_connection() as connection:
results = await execute_query(connection, base_query)
# 如果需要排除视图且使用的是SHOW FULL TABLES
if exclude_views and "FULL" in base_query:
filtered_results = []
# 查找表名和表类型字段
fields = list(results[0].keys()) if results else []
table_field = fields[0] if fields else None
table_type_field = fields[1] if len(fields) > 1 else None
if table_field and table_type_field:
# 基表类型通常是"BASE TABLE"
for item in results:
if item[table_type_field] == 'BASE TABLE':
filtered_results.append(item)
else:
filtered_results = results
else:
filtered_results = results
# 限制返回数量
if limit > 0 and len(filtered_results) > limit:
limited_results = filtered_results[:limit]
is_limited = True
else:
limited_results = filtered_results
is_limited = False
# 构造元数据
metadata_info = {
"metadata_info": {
"operation_type": "表列表查询",
"result_count": len(limited_results),
"total_count": len(results),
"filtered": {
"database": database,
"pattern": pattern,
"exclude_views": exclude_views and "FULL" in base_query,
"limited": is_limited
}
},
"results": limited_results
}
return json.dumps(metadata_info, default=str)
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_columns(table: str, database: Optional[str] = None) -> str:
"""
获取表的列信息
Args:
table: 表名
database: 数据库名称 (可选,默认使用当前连接的数据库)
Returns:
表列信息的JSON字符串
"""
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"数据库名称只能包含字母、数字和下划线"
)
query = f"SHOW COLUMNS FROM `{table}`" if not database else f"SHOW COLUMNS FROM `{database}`.`{table}`"
logger.debug(f"执行查询: {query}")
return await MetadataToolBase.execute_metadata_query(query, operation_type="表列信息查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_describe_table(table: str, database: Optional[str] = None) -> str:
"""
描述表结构
Args:
table: 表名
database: 数据库名称 (可选,默认使用当前连接的数据库)
Returns:
表结构描述的JSON字符串
"""
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"数据库名称只能包含字母、数字和下划线"
)
# DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
logger.debug(f"执行查询: {query}")
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_create_table(table: str, database: Optional[str] = None) -> str:
"""
获取表的创建语句
Args:
table: 表名
database: 数据库名称 (可选,默认使用当前连接的数据库)
Returns:
表创建语句的JSON字符串
"""
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
"数据库名称只能包含字母、数字和下划线"
)
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
query = f"SHOW CREATE TABLE {table_ref}"
logger.debug(f"执行查询: {query}")
return await MetadataToolBase.execute_metadata_query(query, operation_type="表创建语句查询")

View File

@ -0,0 +1,294 @@
"""
MySQL表结构高级查询工具
提供索引、约束、表状态等高级元数据查询功能
"""
import json
import logging
import re
import os
from typing import Any, Dict, List, Optional, Union
from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.db.mysql_operations import get_db_connection, execute_query
logger = logging.getLogger("mysql_server")
# 参数验证函数
def validate_table_name(name: str) -> bool:
"""
验证表名是否合法安全
Args:
name: 要验证的表名
Returns:
如果表名安全返回True否则抛出ValueError
Raises:
ValueError: 当表名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
return True
def validate_database_name(name: str) -> bool:
"""
验证数据库名是否合法安全
Args:
name: 要验证的数据库名
Returns:
如果数据库名安全返回True否则抛出ValueError
Raises:
ValueError: 当数据库名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
return True
def validate_column_name(name: str) -> bool:
"""
验证列名是否合法安全
Args:
name: 要验证的列名
Returns:
如果列名安全返回True否则抛出ValueError
Raises:
ValueError: 当列名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线")
return True
async def execute_schema_query(
query: str,
params: Optional[Dict[str, Any]] = None,
operation_type: str = "元数据查询"
) -> str:
"""
执行表结构查询
Args:
query: SQL查询语句
params: 查询参数 (可选)
operation_type: 操作类型描述
Returns:
查询结果的JSON字符串
"""
with get_db_connection() as connection:
results = await execute_query(connection, query, params)
return MetadataToolBase.format_results(results, operation_type)
def register_schema_tools(mcp: FastMCP):
"""
注册MySQL表结构高级查询工具到MCP服务器
Args:
mcp: FastMCP服务器实例
"""
logger.debug("注册MySQL表结构高级查询工具...")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_indexes(table: str, database: Optional[str] = None) -> str:
"""
获取表的索引信息
Args:
table: 表名
database: 数据库名称 (可选,默认使用当前连接的数据库)
Returns:
表索引信息的JSON字符串
"""
# 参数验证
validate_table_name(table)
if database:
validate_database_name(database)
# 构建查询
table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
query = f"SHOW INDEX FROM {table_ref}"
logger.debug(f"执行查询: {query}")
# 执行查询
return await execute_schema_query(query, operation_type="表索引查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_table_status(database: Optional[str] = None, like_pattern: Optional[str] = None) -> str:
"""
获取表状态信息
Args:
database: 数据库名称 (可选,默认使用当前连接的数据库)
like_pattern: 表名匹配模式 (可选,例如 '%user%')
Returns:
表状态信息的JSON字符串
"""
# 参数验证
if database:
validate_database_name(database)
if like_pattern:
validate_column_name(like_pattern)
# 构建查询
if database:
query = f"SHOW TABLE STATUS FROM `{database}`"
else:
query = "SHOW TABLE STATUS"
if like_pattern:
query += f" LIKE '{like_pattern}'"
logger.debug(f"执行查询: {query}")
# 执行查询
return await execute_schema_query(query, operation_type="表状态查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_show_foreign_keys(table: str, database: Optional[str] = None) -> str:
"""
获取表的外键约束信息
Args:
table: 表名
database: 数据库名称 (可选,默认使用当前连接的数据库)
Returns:
表外键约束信息的JSON字符串
"""
# 参数验证
validate_table_name(table)
if database:
validate_database_name(database)
# 确定数据库名
db_name = database
if not db_name:
# 获取当前数据库
with get_db_connection() as connection:
current_db_results = await execute_query(connection, "SELECT DATABASE() as db")
if current_db_results and 'db' in current_db_results[0]:
db_name = current_db_results[0]['db']
if not db_name:
raise ValueError("无法确定数据库名称请明确指定database参数")
# 使用INFORMATION_SCHEMA查询外键
query = """
SELECT
CONSTRAINT_NAME,
TABLE_NAME,
COLUMN_NAME,
REFERENCED_TABLE_NAME,
REFERENCED_COLUMN_NAME,
UPDATE_RULE,
DELETE_RULE
FROM
INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
JOIN
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
ON
kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
WHERE
kcu.TABLE_SCHEMA = %s
AND kcu.TABLE_NAME = %s
AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
"""
params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table}
logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束")
# 执行查询
return await execute_schema_query(query, params, operation_type="外键约束查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
async def mysql_paginate_results(query: str, page: int = 1, page_size: int = 50) -> str:
"""
分页执行查询以处理大型结果集
Args:
query: SQL查询语句
page: 页码 (从1开始)
page_size: 每页记录数 (默认50)
Returns:
分页结果的JSON字符串
"""
# 参数验证
MetadataToolBase.validate_parameter(
"page", page,
lambda x: isinstance(x, int) and x > 0,
"页码必须是正整数"
)
MetadataToolBase.validate_parameter(
"page_size", page_size,
lambda x: isinstance(x, int) and 1 <= x <= 1000,
"每页记录数必须在1-1000之间"
)
# 检查查询语法
if not query.strip().upper().startswith('SELECT'):
raise ValueError("只支持SELECT查询的分页")
# 计算LIMIT和OFFSET
offset = (page - 1) * page_size
# 在查询末尾添加LIMIT子句
paginated_query = query.strip()
if 'LIMIT' in paginated_query.upper():
raise ValueError("查询已包含LIMIT子句请移除后重试")
paginated_query += f" LIMIT {page_size} OFFSET {offset}"
logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}")
# 获取总记录数(用于计算总页数)
count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table"
with get_db_connection() as connection:
# 执行分页查询
results = await execute_query(connection, paginated_query)
# 获取总记录数
count_results = await execute_query(connection, count_query)
total_records = count_results[0]['total'] if count_results else 0
# 计算总页数
total_pages = (total_records + page_size - 1) // page_size
# 构建分页元数据
pagination_info = {
"metadata_info": {
"operation_type": "分页查询",
"result_count": len(results),
"pagination": {
"page": page,
"page_size": page_size,
"total_records": total_records,
"total_pages": total_pages
}
},
"results": results
}
return json.dumps(pagination_info, default=str)