<feat> 重构配置管理与安全检查机制,新增SQL解析器,优化数据库连接池管理

This commit is contained in:
tangyi
2025-05-09 11:36:30 +08:00
parent e0550a531d
commit 9cad3837a6
17 changed files with 1731 additions and 728 deletions

View File

@ -11,64 +11,10 @@ from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators
logger = logging.getLogger("mysql_server")
# 工具函数: 用于参数验证
def validate_pattern(pattern: str) -> bool:
"""
验证模式字符串是否安全 (防止SQL注入)
Args:
pattern: 要验证的模式字符串
Returns:
如果模式安全返回True否则抛出ValueError
Raises:
ValueError: 当模式包含不安全字符时
"""
# 仅允许字母、数字、下划线和通配符(% 和 _)
if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
return True
def validate_table_name(name: str) -> bool:
"""
验证表名是否合法安全
Args:
name: 要验证的表名
Returns:
如果表名安全返回True否则抛出ValueError
Raises:
ValueError: 当表名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
return True
def validate_database_name(name: str) -> bool:
"""
验证数据库名是否合法安全
Args:
name: 要验证的数据库名
Returns:
如果数据库名安全返回True否则抛出ValueError
Raises:
ValueError: 当数据库名包含不安全字符时
"""
# 仅允许字母、数字和下划线
if not re.match(r'^[a-zA-Z0-9_]+$', name):
raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
return True
def register_metadata_tools(mcp: FastMCP):
"""
注册MySQL元数据查询工具到MCP服务器
@ -98,20 +44,20 @@ def register_metadata_tools(mcp: FastMCP):
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_database_name,
"数据库名称只能包含字母、数字和下划线"
)
if pattern:
MetadataToolBase.validate_parameter(
"pattern", pattern,
lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
SQLValidators.validate_like_pattern,
"模式只能包含字母、数字、下划线和通配符(%_)"
)
MetadataToolBase.validate_parameter(
"limit", limit,
lambda x: isinstance(x, int) and x >= 0,
lambda x: SQLValidators.validate_integer(x, min_value=0),
"返回结果的最大数量必须是非负整数"
)
@ -124,8 +70,8 @@ def register_metadata_tools(mcp: FastMCP):
logger.debug(f"执行查询: {base_query}")
# 执行查询
with get_db_connection() as connection:
# 执行查询 - 使用异步上下文管理器
async with get_db_connection() as connection:
results = await execute_query(connection, base_query)
# 如果需要排除视图且使用的是SHOW FULL TABLES
@ -133,17 +79,46 @@ def register_metadata_tools(mcp: FastMCP):
filtered_results = []
# 查找表名和表类型字段
fields = list(results[0].keys()) if results else []
table_field = fields[0] if fields else None
table_type_field = fields[1] if len(fields) > 1 else None
if results:
# 确定表名和表类型字段名
field_names = list(results[0].keys())
table_type_field = None
table_field = None
# 查找表类型字段 - 这通常是'Table_type',但也检查其他可能的名称
possible_type_fields = ['Table_type', 'table_type', 'type']
for field in possible_type_fields:
if field in field_names:
table_type_field = field
break
# 查找表名字段 - 这可能是结果中的第一个字段
for field in field_names:
if field != table_type_field: # 表名不会是类型字段
if field.lower() in ['table', 'name', 'table_name', 'tables_in_']:
table_field = field
break
# 如果没找到明确的表名字段,使用第一个非类型字段
if not table_field and len(field_names) > 0:
for field in field_names:
if field != table_type_field:
table_field = field
break
if table_field and table_type_field:
# 基表类型通常是"BASE TABLE"
for item in results:
if item[table_type_field] == 'BASE TABLE':
filtered_results.append(item)
# 只有当我们能确定表名和类型字段时才进行过滤
if table_field and table_type_field:
logger.debug(f"表名字段: {table_field}, 表类型字段: {table_type_field}")
# 只保留基表 (BASE TABLE),排除视图和其他对象
for item in results:
if item[table_type_field] == 'BASE TABLE':
filtered_results.append(item)
else:
# 如果无法确定字段,保留所有结果并记录警告
logger.warning("无法确定表类型字段,无法排除视图")
filtered_results = results
else:
filtered_results = results
filtered_results = []
else:
filtered_results = results
@ -189,14 +164,14 @@ def register_metadata_tools(mcp: FastMCP):
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_table_name,
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_database_name,
"数据库名称只能包含字母、数字和下划线"
)
@ -221,22 +196,21 @@ def register_metadata_tools(mcp: FastMCP):
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_table_name,
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_database_name,
"数据库名称只能包含字母、数字和下划线"
)
# DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
logger.debug(f"执行查询: {query}")
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述查询")
@mcp.tool()
@MetadataToolBase.handle_query_error
@ -254,14 +228,14 @@ def register_metadata_tools(mcp: FastMCP):
# 参数验证
MetadataToolBase.validate_parameter(
"table", table,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_table_name,
"表名只能包含字母、数字和下划线"
)
if database:
MetadataToolBase.validate_parameter(
"database", database,
lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
SQLValidators.validate_database_name,
"数据库名称只能包含字母、数字和下划线"
)