From e0550a531d498e42b2b9acb208cbeb1896462a1d Mon Sep 17 00:00:00 2001 From: tangyi Date: Tue, 8 Apr 2025 11:40:25 +0800 Subject: [PATCH] =?UTF-8?q?=20=E5=A2=9E=E5=8A=A0=E5=85=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=93=8D=E4=BD=9C=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 101 +++++++--- src/db/mysql_operations.py | 30 ++- src/security/interceptor.py | 17 +- src/security/sql_analyzer.py | 12 ++ src/server.py | 19 +- src/tools/metadata_base_tool.py | 123 ++++++++++++ src/tools/mysql_info_tool.py | 332 +++++++++++++++++++++++++++++++ src/tools/mysql_metadata_tool.py | 272 +++++++++++++++++++++++++ src/tools/mysql_schema_tool.py | 294 +++++++++++++++++++++++++++ 9 files changed, 1167 insertions(+), 33 deletions(-) create mode 100644 src/tools/metadata_base_tool.py create mode 100644 src/tools/mysql_info_tool.py create mode 100644 src/tools/mysql_metadata_tool.py create mode 100644 src/tools/mysql_schema_tool.py diff --git a/README.md b/README.md index 4a1e3cf..41eb010 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,45 @@ - 危险操作拦截 - WHERE子句强制检查 - 自动返回修改操作影响的行数 +- 敏感信息保护机制 +- 自动对元数据查询结果进行格式化和增强 + +## API接口功能 + +系统提供以下四大类工具: + +### 基础查询工具 + +- `mysql_query`: 执行任意SQL查询,支持参数化查询 + +### 元数据查询工具 + +- `mysql_show_tables`: 获取数据库中的表列表,支持模式匹配和限制结果数量 +- `mysql_show_columns`: 获取表的列信息 +- `mysql_describe_table`: 描述表结构 +- `mysql_show_create_table`: 获取表的创建语句 + +### 数据库信息查询工具 + +- `mysql_show_databases`: 获取所有数据库列表,支持过滤系统数据库 +- `mysql_show_variables`: 获取MySQL服务器变量 +- `mysql_show_status`: 获取MySQL服务器状态信息 + +### 表结构高级查询工具 + +- `mysql_show_indexes`: 获取表的索引信息 +- `mysql_show_table_status`: 获取表状态信息 +- `mysql_show_foreign_keys`: 获取表的外键约束信息 +- `mysql_paginate_results`: 提供结果分页功能 ## 系统要求 - Python 3.6+ - MySQL服务器 +- 依赖包: + - mysql-connector-python + - python-dotenv + - mcp (FastMCP框架) ## 安装步骤 @@ -57,34 +91,33 @@ pip install -r requirements.txt - `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL) - `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔) - `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false) +- `ALLOW_SENSITIVE_INFO`: 是否允许查询敏感信息(true/false) +- `SENSITIVE_INFO_FIELDS`: 自定义敏感字段模式列表(逗号分隔) -## SQL安全机制 +## 安全机制详解 ### 风险等级控制 -- LOW: 查询操作(SELECT) +- LOW: 查询操作(SELECT)和元数据操作(SHOW, DESCRIBE等) - MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE) - HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE - CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作 -### 环境变量加载顺序 -项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。 +### 环境特性差异 +- 开发环境: + - 允许较高风险的操作 + - 不隐藏敏感信息 + - 提供详细的错误信息 +- 生产环境: + - 默认只允许LOW风险操作 + - 严格限制数据修改 + - 自动隐藏敏感信息 + - 错误信息不暴露实现细节 -### 安全检查机制 -- 强制要求UPDATE/DELETE操作包含WHERE子句 -- SQL语句语法检查 -- 危险操作模式检测 -- 自动识别受影响的表 -- 风险等级评估 - -### 环境特性 -- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则 -- 生产环境:默认只允许LOW风险操作,严格限制数据修改 - -### 安全拦截 -- SQL语句长度限制 -- 危险操作模式检测 -- 自动识别受影响的表 -- SQL注入防护 +### 敏感信息保护 +系统会自动检测并隐藏包含以下关键词的变量/状态值: +- password、auth、credential、key、secret、private +- ssl、tls、cipher、certificate +- host、path、directory等系统路径信息 ### 事务管理 - 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务 @@ -105,17 +138,23 @@ python src/server.py ``` . -├── src/ # 源代码目录 -│ ├── server.py # 主服务器文件 -│ ├── db/ # 数据库相关代码 -│ ├── security/ # SQL安全相关代码 -│ │ ├── interceptor.py # SQL拦截器 +├── src/ # 源代码目录 +│ ├── server.py # 主服务器文件 +│ ├── db/ # 数据库相关代码 +│ │ └── mysql_operations.py # MySQL操作实现 +│ ├── security/ # SQL安全相关代码 +│ │ ├── interceptor.py # SQL拦截器 │ │ ├── query_limiter.py # SQL安全检查器 │ │ └── sql_analyzer.py # SQL分析器 -│ └── tools/ # 工具类代码 -├── tests/ # 测试代码目录 -├── .env.example # 环境变量示例文件 -└── requirements.txt # 项目依赖文件 +│ └── tools/ # 工具类代码 +│ ├── mysql_tool.py # 基础查询工具 +│ ├── mysql_metadata_tool.py # 元数据查询工具 +│ ├── mysql_info_tool.py # 数据库信息查询工具 +│ ├── mysql_schema_tool.py # 表结构高级查询工具 +│ └── metadata_base_tool.py # 元数据工具基类 +├── tests/ # 测试代码目录 +├── .env.example # 环境变量示例文件 +└── requirements.txt # 项目依赖文件 ``` ## 常见问题解决 @@ -136,6 +175,10 @@ python src/server.py - 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS - 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别 +### 无法查看敏感信息 +- 在开发环境中,设置ALLOW_SENSITIVE_INFO=true +- 在生产环境中,敏感信息默认会被隐藏,这是安全特性 + ## 日志系统 服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。 diff --git a/src/db/mysql_operations.py b/src/db/mysql_operations.py index 93e6c05..9e849da 100644 --- a/src/db/mysql_operations.py +++ b/src/db/mysql_operations.py @@ -81,6 +81,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] ValueError: 当查询执行失败时 """ cursor = None + operation = None # 初始化操作类型变量 try: # 安全检查 if not await sql_interceptor.check_operation(query): @@ -105,6 +106,33 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据") return [{'affected_rows': affected_rows}] + # 处理元数据查询操作 + if operation in sql_analyzer.metadata_operations: + # 获取结果集 + results = cursor.fetchall() + + # 没有结果时返回空列表但添加元信息 + if not results: + logger.debug(f"元数据查询 {operation} 没有返回结果") + return [{'metadata_operation': operation, 'result_count': 0}] + + # 优化结果格式 - 为元数据结果添加额外信息 + metadata_results = [] + for row in results: + # 对某些特定元数据查询进行特殊处理 + if operation == 'SHOW' and 'Table' in row: + # SHOW TABLES 结果增强 + row['table_name'] = row['Table'] + elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row: + # DESC/DESCRIBE 表结构结果增强 + row['column_name'] = row['Field'] + row['data_type'] = row['Type'] + + metadata_results.append(row) + + logger.debug(f"元数据查询 {operation} 返回 {len(metadata_results)} 条结果") + return metadata_results + # 对于查询操作,返回结果集 results = cursor.fetchall() logger.debug(f"查询返回 {len(results)} 条结果") @@ -115,7 +143,7 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] raise except mysql.connector.Error as query_err: # 如果发生错误,进行回滚 - if operation in {'UPDATE', 'DELETE', 'INSERT'}: + if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}: # 确保operation已定义 try: connection.rollback() logger.debug("事务已回滚") diff --git a/src/security/interceptor.py b/src/security/interceptor.py index 6ae70b0..e130027 100644 --- a/src/security/interceptor.py +++ b/src/security/interceptor.py @@ -46,7 +46,15 @@ class SQLInterceptor: raise SecurityException("SQL语句格式无效") operation = sql_parts[0].upper() - if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}: + # 更新支持的操作类型列表,包括元数据操作 + supported_operations = { + 'SELECT', 'INSERT', 'UPDATE', 'DELETE', + 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE', + 'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP', + 'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE' + } + + if operation not in supported_operations: raise SecurityException(f"不支持的SQL操作: {operation}") # 分析SQL风险 @@ -65,9 +73,14 @@ class SQLInterceptor: f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}" ) + # 确定操作类型(DDL, DML 或 元数据) + operation_category = "元数据操作" if operation in self.analyzer.metadata_operations else ( + "DDL操作" if operation in self.analyzer.ddl_operations else "DML操作" + ) + # 记录详细日志 logger.info( - f"SQL操作检查通过 - " + f"SQL{operation_category}检查通过 - " f"操作: {risk_analysis['operation']}, " f"风险等级: {risk_analysis['risk_level'].name}, " f"影响表: {', '.join(risk_analysis['affected_tables'])}" diff --git a/src/security/sql_analyzer.py b/src/security/sql_analyzer.py index 4f99aa8..879c785 100644 --- a/src/security/sql_analyzer.py +++ b/src/security/sql_analyzer.py @@ -38,6 +38,12 @@ class SQLOperationType: 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE' } + # 添加元数据操作集合 + self.metadata_operations = { + 'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP', + 'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE' + } + # 风险等级配置 self.allowed_risk_levels = self._parse_risk_levels() self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS') @@ -132,11 +138,17 @@ class SQLOperationType: - UPDATE/DELETE(有WHERE)=> MEDIUM - UPDATE(无WHERE)=> HIGH - DELETE(无WHERE)=> CRITICAL + 4. 元数据操作: + - SHOW/DESC/DESCRIBE等 => LOW """ # 危险操作 if is_dangerous: return SQLRiskLevel.CRITICAL + # 元数据操作 + if operation in self.metadata_operations: + return SQLRiskLevel.LOW # 元数据查询视为低风险操作 + # 生产环境中非SELECT操作 if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT': return SQLRiskLevel.CRITICAL diff --git a/src/server.py b/src/server.py index d521b27..85c29c6 100644 --- a/src/server.py +++ b/src/server.py @@ -8,6 +8,9 @@ load_dotenv() # 导入自定义模块 - 确保在load_dotenv之后导入 from src.tools.mysql_tool import register_mysql_tool +from src.tools.mysql_metadata_tool import register_metadata_tools +from src.tools.mysql_info_tool import register_info_tools +from src.tools.mysql_schema_tool import register_schema_tools # 配置日志 logging.basicConfig( @@ -19,6 +22,8 @@ logger = logging.getLogger("mysql_server") # 记录环境变量加载情况 logger.debug("已加载环境变量") logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}") +logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}") +logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}") # 尝试导入MySQL连接器 try: @@ -40,9 +45,21 @@ logger.debug("正在创建MCP服务器实例...") mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse') logger.debug("MCP服务器实例创建完成") -# 注册MySQL工具 +# 注册MySQL基础查询工具 register_mysql_tool(mcp) +# 注册MySQL元数据查询工具 +register_metadata_tools(mcp) +logger.debug("已注册元数据查询工具") + +# 注册MySQL数据库信息查询工具 +register_info_tools(mcp) +logger.debug("已注册数据库信息查询工具") + +# 注册MySQL表结构高级查询工具 +register_schema_tools(mcp) +logger.debug("已注册表结构高级查询工具") + def start_server(): """启动SSE服务器的同步包装器""" logger.debug("开始启动MySQL查询服务器...") diff --git a/src/tools/metadata_base_tool.py b/src/tools/metadata_base_tool.py new file mode 100644 index 0000000..78ac8d4 --- /dev/null +++ b/src/tools/metadata_base_tool.py @@ -0,0 +1,123 @@ +""" +MySQL元数据工具基类 +提供元数据查询工具的共享功能 +""" + +import json +import logging +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable +import functools + +from src.db.mysql_operations import get_db_connection, execute_query + +T = TypeVar('T') +logger = logging.getLogger("mysql_server") + +class MySQLToolError(Exception): + """MySQL工具异常基类""" + pass + +class ParameterValidationError(MySQLToolError): + """参数验证错误""" + pass + +class QueryExecutionError(MySQLToolError): + """查询执行错误""" + pass + +class MetadataToolBase: + """ + 元数据查询工具基类 + 提供共享功能: + - 错误处理 + - 结果格式化 + - 参数验证 + - 异常处理 + """ + + @staticmethod + def validate_parameter(param_name: str, param_value: Any, validator: Callable[[Any], bool], + error_message: str) -> None: + """ + 验证参数是否有效 + + Args: + param_name: 参数名称 + param_value: 参数值 + validator: 验证函数 + error_message: 错误消息 + + Raises: + ParameterValidationError: 当参数验证失败时 + """ + if param_value is not None and not validator(param_value): + raise ParameterValidationError(f"{param_name} - {error_message}") + + @staticmethod + def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str: + """ + 格式化查询结果为JSON字符串 + + Args: + results: 查询结果列表 + operation_type: 操作类型描述 + + Returns: + 格式化后的JSON字符串 + """ + try: + # 添加元数据头部信息 + metadata_info = { + "metadata_info": { + "operation_type": operation_type, + "result_count": len(results) + }, + "results": results + } + return json.dumps(metadata_info, default=str) + except Exception as e: + logger.error(f"结果格式化失败: {str(e)}") + # 如果格式化失败,尝试直接序列化结果 + return json.dumps({"error": f"结果格式化失败: {str(e)}"}) + + @staticmethod + def handle_query_error(func): + """ + 装饰器: 统一处理查询执行过程中的错误 + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except ParameterValidationError as e: + logger.error(f"参数验证错误: {str(e)}") + return json.dumps({"error": f"参数错误: {str(e)}"}) + except QueryExecutionError as e: + logger.error(f"查询执行错误: {str(e)}") + return json.dumps({"error": f"查询执行失败: {str(e)}"}) + except Exception as e: + logger.error(f"未预期的错误: {str(e)}") + return json.dumps({"error": f"操作失败: {str(e)}"}) + return wrapper + + @staticmethod + async def execute_metadata_query(query: str, params: Optional[Dict[str, Any]] = None, + operation_type: str = "元数据查询") -> str: + """ + 执行元数据查询并返回格式化结果 + + Args: + query: SQL查询语句 + params: 查询参数 (可选) + operation_type: 操作类型描述 + + Returns: + 查询结果的JSON字符串 + """ + try: + with get_db_connection() as connection: + results = await execute_query(connection, query, params) + return MetadataToolBase.format_results(results, operation_type) + except Exception as e: + logger.error(f"元数据查询执行失败: {str(e)}") + raise QueryExecutionError(str(e)) \ No newline at end of file diff --git a/src/tools/mysql_info_tool.py b/src/tools/mysql_info_tool.py new file mode 100644 index 0000000..8a0849b --- /dev/null +++ b/src/tools/mysql_info_tool.py @@ -0,0 +1,332 @@ +""" +MySQL数据库信息查询工具 +提供数据库、变量和状态等系统信息查询功能 +""" + +import json +import logging +import re +import os +from typing import Any, Dict, List, Optional +from mcp.server.fastmcp import FastMCP + +from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError +from src.security.sql_analyzer import EnvironmentType +from src.db.mysql_operations import get_db_connection, execute_query + +logger = logging.getLogger("mysql_server") + +# 自定义异常类 +class SecurityError(QueryExecutionError): + """安全限制错误""" + pass + +# 从环境变量读取敏感字段列表 +def get_sensitive_patterns(): + """从环境变量获取敏感字段模式列表""" + default_patterns = [ + r'password', r'auth', r'credential', r'key', r'secret', r'private', + r'host', r'path', r'directory', r'ssl', r'iptables', r'filter' + ] + + env_patterns = os.getenv('SENSITIVE_INFO_FIELDS', '') + if env_patterns: + # 合并自定义模式和默认模式 + patterns = [pattern.strip() for pattern in env_patterns.split(',') if pattern.strip()] + return list(set(patterns + default_patterns)) + + return default_patterns + +# 敏感变量和状态关键字列表 +SENSITIVE_VARIABLE_PATTERNS = get_sensitive_patterns() + +# 敏感变量名前缀,生产环境中这些变量的值会被隐藏 +SENSITIVE_VARIABLE_PREFIXES = [ + "password", "auth", "secret", "key", "certificate", "ssl", "tls", "cipher", + "authentication", "secure", "credential", "token" +] + +def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool: + """ + 检查当前环境是否允许执行特定类型的查询 + + Args: + env_type: 环境类型(开发/生产) + query_type: 查询类型 + + Returns: + bool: 是否允许执行 + """ + # 开发环境不限制查询 + if env_type == EnvironmentType.DEVELOPMENT: + return True + + # 生产环境限制敏感信息查询 + sensitive_queries = ['variables', 'status', 'processlist'] + + if query_type in sensitive_queries: + # 检查是否在环境变量中明确允许 + allow_sensitive = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() == 'true' + if not allow_sensitive: + logger.warning(f"生产环境中禁止执行敏感查询: {query_type}") + return False + + return True + +def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[str] = None) -> List[Dict[str, Any]]: + """ + 过滤结果中的敏感信息 + + Args: + results: 查询结果 + filter_patterns: 敏感信息的正则表达式模式列表 + + Returns: + 过滤后的结果列表 + """ + if not filter_patterns: + filter_patterns = SENSITIVE_VARIABLE_PATTERNS + + filtered_results = [] + for item in results: + # 复制一份,避免修改原始数据 + filtered_item = item.copy() + + # 检查常见的变量名字段 + for field in ['Variable_name', 'variable_name', 'name']: + if field in filtered_item: + var_name = filtered_item[field].lower() + # 检查是否匹配敏感模式 + is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns) + + if is_sensitive: + # 敏感信息,隐藏具体的值 + for value_field in ['Value', 'value', 'variable_value']: + if value_field in filtered_item: + filtered_item[value_field] = '*** HIDDEN ***' + + filtered_results.append(filtered_item) + + return filtered_results + +def register_info_tools(mcp: FastMCP): + """ + 注册MySQL数据库信息查询工具到MCP服务器 + + Args: + mcp: FastMCP服务器实例 + """ + logger.debug("注册MySQL数据库信息查询工具...") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str: + """ + 获取所有数据库列表,支持筛选和限制结果数量 + + Args: + pattern: 数据库名称匹配模式 (可选, 例如 '%test%') + limit: 返回结果的最大数量 (默认100,设为0表示无限制) + exclude_system: 是否排除系统数据库 (默认为True) + + Returns: + 数据库列表的JSON字符串 + """ + # 参数验证 + if pattern: + MetadataToolBase.validate_parameter( + "pattern", pattern, + lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x), + "模式只能包含字母、数字、下划线和通配符(%_)" + ) + + MetadataToolBase.validate_parameter( + "limit", limit, + lambda x: isinstance(x, int) and x >= 0, + "返回结果的最大数量必须是非负整数" + ) + + # 构建基础查询 + query = "SHOW DATABASES" + + # 执行查询 + with get_db_connection() as connection: + # 先获取所有数据库 + results = await execute_query(connection, query) + + # 通常结果中每个数据库名会在"Database"字段 + db_field = next((k for k in results[0].keys() if k.lower() == 'database'), None) if results else None + + if not db_field: + logger.warning("查询结果未找到数据库名称字段") + return MetadataToolBase.format_results(results, operation_type="数据库列表查询") + + # 对结果进行过滤 + filtered_results = [] + system_dbs = ['information_schema', 'mysql', 'performance_schema', 'sys'] + + for item in results: + db_name = item[db_field] + + # 排除系统数据库 + if exclude_system and db_name.lower() in system_dbs: + continue + + # 根据模式过滤 + if pattern: + pattern_regex = pattern.replace('%', '.*').replace('_', '.') + if not re.search(pattern_regex, db_name, re.IGNORECASE): + continue + + filtered_results.append(item) + + # 限制返回数量 + if limit > 0 and len(filtered_results) > limit: + filtered_results = filtered_results[:limit] + logger.debug(f"结果数量已限制为前{limit}个") + + # 返回结果 + metadata_info = { + "metadata_info": { + "operation_type": "数据库列表查询", + "result_count": len(filtered_results), + "total_count": len(results), + "filtered": { + "pattern": pattern, + "exclude_system": exclude_system, + "limited": len(filtered_results) < len(results) + } + }, + "results": filtered_results + } + + return json.dumps(metadata_info, default=str) + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_variables(pattern: Optional[str] = None, global_scope: bool = False) -> str: + """ + 获取MySQL系统变量 + + Args: + pattern: 变量名称匹配模式 (可选, 例如 '%buffer%') + global_scope: 是否查询全局变量 (默认为会话变量) + + Returns: + 系统变量的JSON字符串 + """ + # 获取当前环境类型 + from src.security.sql_analyzer import sql_analyzer + env_type = sql_analyzer.env_type + + # 检查环境权限 + if not check_environment_permission(env_type, 'variables'): + raise SecurityError("当前环境不允许查询系统变量") + + # 参数验证 + if pattern: + MetadataToolBase.validate_parameter( + "pattern", pattern, + lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x), + "变量模式只能包含字母、数字、下划线和通配符(%_)" + ) + + # 构建查询 + scope = "GLOBAL" if global_scope else "SESSION" + query = f"SHOW {scope} VARIABLES" + if pattern: + query += f" LIKE '{pattern}'" + + logger.debug(f"执行查询: {query}") + + with get_db_connection() as connection: + results = await execute_query(connection, query) + + # 生产环境中过滤敏感信息 + if env_type == EnvironmentType.PRODUCTION: + results = filter_sensitive_info(results) + + return MetadataToolBase.format_results(results, operation_type="系统变量查询") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_status(pattern: Optional[str] = None, global_scope: bool = False) -> str: + """ + 获取MySQL服务器状态 + + Args: + pattern: 状态名称匹配模式 (可选, 例如 '%conn%') + global_scope: 是否查询全局状态 (默认为会话状态) + + Returns: + 服务器状态的JSON字符串 + """ + # 获取当前环境类型 + from src.security.sql_analyzer import sql_analyzer + env_type = sql_analyzer.env_type + + # 检查环境权限 + if not check_environment_permission(env_type, 'status'): + raise SecurityError("当前环境不允许查询系统状态") + + # 参数验证 + if pattern: + MetadataToolBase.validate_parameter( + "pattern", pattern, + lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x), + "状态模式只能包含字母、数字、下划线和通配符(%_)" + ) + + # 构建查询 + scope = "GLOBAL" if global_scope else "SESSION" + query = f"SHOW {scope} STATUS" + if pattern: + query += f" LIKE '{pattern}'" + + logger.debug(f"执行查询: {query}") + + with get_db_connection() as connection: + results = await execute_query(connection, query) + + # 生产环境中过滤敏感信息 + if env_type == EnvironmentType.PRODUCTION: + results = filter_sensitive_info(results) + + return MetadataToolBase.format_results(results, operation_type="服务器状态查询") + +# 工具函数: 用于参数验证 +def validate_pattern(pattern: str) -> bool: + """ + 验证模式字符串是否安全 (防止SQL注入) + + Args: + pattern: 要验证的模式字符串 + + Returns: + 如果模式安全返回True,否则抛出ValueError + + Raises: + ValueError: 当模式包含不安全字符时 + """ + # 仅允许字母、数字、下划线和通配符(% 和 _) + if not re.match(r'^[a-zA-Z0-9_%]+$', pattern): + raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)") + return True + +def validate_engine_name(name: str) -> bool: + """ + 验证存储引擎名称是否合法安全 + + Args: + name: 要验证的引擎名称 + + Returns: + 如果引擎名称安全返回True,否则抛出ValueError + + Raises: + ValueError: 当引擎名称包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线") + return True \ No newline at end of file diff --git a/src/tools/mysql_metadata_tool.py b/src/tools/mysql_metadata_tool.py new file mode 100644 index 0000000..6fbbb7f --- /dev/null +++ b/src/tools/mysql_metadata_tool.py @@ -0,0 +1,272 @@ +""" +MySQL元数据查询工具 +提供表结构等元数据信息查询功能 +""" + +import json +import logging +import re +from typing import Any, Dict, List, Optional +from mcp.server.fastmcp import FastMCP + +from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError +from src.db.mysql_operations import get_db_connection, execute_query + +logger = logging.getLogger("mysql_server") + +# 工具函数: 用于参数验证 +def validate_pattern(pattern: str) -> bool: + """ + 验证模式字符串是否安全 (防止SQL注入) + + Args: + pattern: 要验证的模式字符串 + + Returns: + 如果模式安全返回True,否则抛出ValueError + + Raises: + ValueError: 当模式包含不安全字符时 + """ + # 仅允许字母、数字、下划线和通配符(% 和 _) + if not re.match(r'^[a-zA-Z0-9_%]+$', pattern): + raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)") + return True + +def validate_table_name(name: str) -> bool: + """ + 验证表名是否合法安全 + + Args: + name: 要验证的表名 + + Returns: + 如果表名安全返回True,否则抛出ValueError + + Raises: + ValueError: 当表名包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线") + return True + +def validate_database_name(name: str) -> bool: + """ + 验证数据库名是否合法安全 + + Args: + name: 要验证的数据库名 + + Returns: + 如果数据库名安全返回True,否则抛出ValueError + + Raises: + ValueError: 当数据库名包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线") + return True + +def register_metadata_tools(mcp: FastMCP): + """ + 注册MySQL元数据查询工具到MCP服务器 + + Args: + mcp: FastMCP服务器实例 + """ + logger.debug("注册MySQL元数据查询工具...") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_tables(database: Optional[str] = None, pattern: Optional[str] = None, + limit: int = 100, exclude_views: bool = False) -> str: + """ + 获取数据库中的表列表,支持筛选和限制结果数量 + + Args: + database: 数据库名称 (可选,默认使用当前连接的数据库) + pattern: 表名匹配模式 (可选, 例如 '%user%') + limit: 返回结果的最大数量 (默认100,设为0表示无限制) + exclude_views: 是否排除视图 (默认为False) + + Returns: + 表列表的JSON字符串 + """ + # 参数验证 + if database: + MetadataToolBase.validate_parameter( + "database", database, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "数据库名称只能包含字母、数字和下划线" + ) + + if pattern: + MetadataToolBase.validate_parameter( + "pattern", pattern, + lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x), + "模式只能包含字母、数字、下划线和通配符(%_)" + ) + + MetadataToolBase.validate_parameter( + "limit", limit, + lambda x: isinstance(x, int) and x >= 0, + "返回结果的最大数量必须是非负整数" + ) + + # 基础查询 + base_query = "SHOW FULL TABLES" if exclude_views else "SHOW TABLES" + if database: + base_query += f" FROM `{database}`" + if pattern: + base_query += f" LIKE '{pattern}'" + + logger.debug(f"执行查询: {base_query}") + + # 执行查询 + with get_db_connection() as connection: + results = await execute_query(connection, base_query) + + # 如果需要排除视图,且使用的是SHOW FULL TABLES + if exclude_views and "FULL" in base_query: + filtered_results = [] + + # 查找表名和表类型字段 + fields = list(results[0].keys()) if results else [] + table_field = fields[0] if fields else None + table_type_field = fields[1] if len(fields) > 1 else None + + if table_field and table_type_field: + # 基表类型通常是"BASE TABLE" + for item in results: + if item[table_type_field] == 'BASE TABLE': + filtered_results.append(item) + else: + filtered_results = results + else: + filtered_results = results + + # 限制返回数量 + if limit > 0 and len(filtered_results) > limit: + limited_results = filtered_results[:limit] + is_limited = True + else: + limited_results = filtered_results + is_limited = False + + # 构造元数据 + metadata_info = { + "metadata_info": { + "operation_type": "表列表查询", + "result_count": len(limited_results), + "total_count": len(results), + "filtered": { + "database": database, + "pattern": pattern, + "exclude_views": exclude_views and "FULL" in base_query, + "limited": is_limited + } + }, + "results": limited_results + } + + return json.dumps(metadata_info, default=str) + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_columns(table: str, database: Optional[str] = None) -> str: + """ + 获取表的列信息 + + Args: + table: 表名 + database: 数据库名称 (可选,默认使用当前连接的数据库) + + Returns: + 表列信息的JSON字符串 + """ + # 参数验证 + MetadataToolBase.validate_parameter( + "table", table, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "表名只能包含字母、数字和下划线" + ) + + if database: + MetadataToolBase.validate_parameter( + "database", database, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "数据库名称只能包含字母、数字和下划线" + ) + + query = f"SHOW COLUMNS FROM `{table}`" if not database else f"SHOW COLUMNS FROM `{database}`.`{table}`" + logger.debug(f"执行查询: {query}") + + return await MetadataToolBase.execute_metadata_query(query, operation_type="表列信息查询") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_describe_table(table: str, database: Optional[str] = None) -> str: + """ + 描述表结构 + + Args: + table: 表名 + database: 数据库名称 (可选,默认使用当前连接的数据库) + + Returns: + 表结构描述的JSON字符串 + """ + # 参数验证 + MetadataToolBase.validate_parameter( + "table", table, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "表名只能包含字母、数字和下划线" + ) + + if database: + MetadataToolBase.validate_parameter( + "database", database, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "数据库名称只能包含字母、数字和下划线" + ) + + # DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同 + query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`" + logger.debug(f"执行查询: {query}") + + return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_create_table(table: str, database: Optional[str] = None) -> str: + """ + 获取表的创建语句 + + Args: + table: 表名 + database: 数据库名称 (可选,默认使用当前连接的数据库) + + Returns: + 表创建语句的JSON字符串 + """ + # 参数验证 + MetadataToolBase.validate_parameter( + "table", table, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "表名只能包含字母、数字和下划线" + ) + + if database: + MetadataToolBase.validate_parameter( + "database", database, + lambda x: re.match(r'^[a-zA-Z0-9_]+$', x), + "数据库名称只能包含字母、数字和下划线" + ) + + table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`" + query = f"SHOW CREATE TABLE {table_ref}" + logger.debug(f"执行查询: {query}") + + return await MetadataToolBase.execute_metadata_query(query, operation_type="表创建语句查询") \ No newline at end of file diff --git a/src/tools/mysql_schema_tool.py b/src/tools/mysql_schema_tool.py new file mode 100644 index 0000000..ae60a63 --- /dev/null +++ b/src/tools/mysql_schema_tool.py @@ -0,0 +1,294 @@ +""" +MySQL表结构高级查询工具 +提供索引、约束、表状态等高级元数据查询功能 +""" + +import json +import logging +import re +import os +from typing import Any, Dict, List, Optional, Union +from mcp.server.fastmcp import FastMCP + +from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError +from src.db.mysql_operations import get_db_connection, execute_query + +logger = logging.getLogger("mysql_server") + +# 参数验证函数 +def validate_table_name(name: str) -> bool: + """ + 验证表名是否合法安全 + + Args: + name: 要验证的表名 + + Returns: + 如果表名安全返回True,否则抛出ValueError + + Raises: + ValueError: 当表名包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线") + return True + +def validate_database_name(name: str) -> bool: + """ + 验证数据库名是否合法安全 + + Args: + name: 要验证的数据库名 + + Returns: + 如果数据库名安全返回True,否则抛出ValueError + + Raises: + ValueError: 当数据库名包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线") + return True + +def validate_column_name(name: str) -> bool: + """ + 验证列名是否合法安全 + + Args: + name: 要验证的列名 + + Returns: + 如果列名安全返回True,否则抛出ValueError + + Raises: + ValueError: 当列名包含不安全字符时 + """ + # 仅允许字母、数字和下划线 + if not re.match(r'^[a-zA-Z0-9_]+$', name): + raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线") + return True + +async def execute_schema_query( + query: str, + params: Optional[Dict[str, Any]] = None, + operation_type: str = "元数据查询" +) -> str: + """ + 执行表结构查询 + + Args: + query: SQL查询语句 + params: 查询参数 (可选) + operation_type: 操作类型描述 + + Returns: + 查询结果的JSON字符串 + """ + with get_db_connection() as connection: + results = await execute_query(connection, query, params) + return MetadataToolBase.format_results(results, operation_type) + +def register_schema_tools(mcp: FastMCP): + """ + 注册MySQL表结构高级查询工具到MCP服务器 + + Args: + mcp: FastMCP服务器实例 + """ + logger.debug("注册MySQL表结构高级查询工具...") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_indexes(table: str, database: Optional[str] = None) -> str: + """ + 获取表的索引信息 + + Args: + table: 表名 + database: 数据库名称 (可选,默认使用当前连接的数据库) + + Returns: + 表索引信息的JSON字符串 + """ + # 参数验证 + validate_table_name(table) + + if database: + validate_database_name(database) + + # 构建查询 + table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`" + query = f"SHOW INDEX FROM {table_ref}" + logger.debug(f"执行查询: {query}") + + # 执行查询 + return await execute_schema_query(query, operation_type="表索引查询") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_table_status(database: Optional[str] = None, like_pattern: Optional[str] = None) -> str: + """ + 获取表状态信息 + + Args: + database: 数据库名称 (可选,默认使用当前连接的数据库) + like_pattern: 表名匹配模式 (可选,例如 '%user%') + + Returns: + 表状态信息的JSON字符串 + """ + # 参数验证 + if database: + validate_database_name(database) + + if like_pattern: + validate_column_name(like_pattern) + + # 构建查询 + if database: + query = f"SHOW TABLE STATUS FROM `{database}`" + else: + query = "SHOW TABLE STATUS" + + if like_pattern: + query += f" LIKE '{like_pattern}'" + + logger.debug(f"执行查询: {query}") + + # 执行查询 + return await execute_schema_query(query, operation_type="表状态查询") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_show_foreign_keys(table: str, database: Optional[str] = None) -> str: + """ + 获取表的外键约束信息 + + Args: + table: 表名 + database: 数据库名称 (可选,默认使用当前连接的数据库) + + Returns: + 表外键约束信息的JSON字符串 + """ + # 参数验证 + validate_table_name(table) + + if database: + validate_database_name(database) + + # 确定数据库名 + db_name = database + if not db_name: + # 获取当前数据库 + with get_db_connection() as connection: + current_db_results = await execute_query(connection, "SELECT DATABASE() as db") + if current_db_results and 'db' in current_db_results[0]: + db_name = current_db_results[0]['db'] + + if not db_name: + raise ValueError("无法确定数据库名称,请明确指定database参数") + + # 使用INFORMATION_SCHEMA查询外键 + query = """ + SELECT + CONSTRAINT_NAME, + TABLE_NAME, + COLUMN_NAME, + REFERENCED_TABLE_NAME, + REFERENCED_COLUMN_NAME, + UPDATE_RULE, + DELETE_RULE + FROM + INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + JOIN + INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc + ON + kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + WHERE + kcu.TABLE_SCHEMA = %s + AND kcu.TABLE_NAME = %s + AND kcu.REFERENCED_TABLE_NAME IS NOT NULL + """ + params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table} + + logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束") + + # 执行查询 + return await execute_schema_query(query, params, operation_type="外键约束查询") + + @mcp.tool() + @MetadataToolBase.handle_query_error + async def mysql_paginate_results(query: str, page: int = 1, page_size: int = 50) -> str: + """ + 分页执行查询以处理大型结果集 + + Args: + query: SQL查询语句 + page: 页码 (从1开始) + page_size: 每页记录数 (默认50) + + Returns: + 分页结果的JSON字符串 + """ + # 参数验证 + MetadataToolBase.validate_parameter( + "page", page, + lambda x: isinstance(x, int) and x > 0, + "页码必须是正整数" + ) + + MetadataToolBase.validate_parameter( + "page_size", page_size, + lambda x: isinstance(x, int) and 1 <= x <= 1000, + "每页记录数必须在1-1000之间" + ) + + # 检查查询语法 + if not query.strip().upper().startswith('SELECT'): + raise ValueError("只支持SELECT查询的分页") + + # 计算LIMIT和OFFSET + offset = (page - 1) * page_size + + # 在查询末尾添加LIMIT子句 + paginated_query = query.strip() + if 'LIMIT' in paginated_query.upper(): + raise ValueError("查询已包含LIMIT子句,请移除后重试") + + paginated_query += f" LIMIT {page_size} OFFSET {offset}" + + logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}") + + # 获取总记录数(用于计算总页数) + count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table" + + with get_db_connection() as connection: + # 执行分页查询 + results = await execute_query(connection, paginated_query) + + # 获取总记录数 + count_results = await execute_query(connection, count_query) + total_records = count_results[0]['total'] if count_results else 0 + + # 计算总页数 + total_pages = (total_records + page_size - 1) // page_size + + # 构建分页元数据 + pagination_info = { + "metadata_info": { + "operation_type": "分页查询", + "result_count": len(results), + "pagination": { + "page": page, + "page_size": page_size, + "total_records": total_records, + "total_pages": total_pages + } + }, + "results": results + } + + return json.dumps(pagination_info, default=str) \ No newline at end of file