Mirror of https://github.com/mangooer/mysql-mcp-server-sse.git (synced 2025-12-18 03:26:39 +08:00)

<feat> Refactor configuration management and the security-check mechanism, add an SQL parser, and optimize database connection pool management

This commit is contained in:

129 src/config.py Normal file

@@ -0,0 +1,129 @@
import os
from typing import Set, List
from enum import IntEnum, Enum

# Environment variables
class EnvironmentType(Enum):
    """Environment type"""
    DEVELOPMENT = 'development'
    PRODUCTION = 'production'

class SQLRiskLevel(IntEnum):
    """SQL operation risk levels"""
    LOW = 1       # read queries (SELECT)
    MEDIUM = 2    # basic data modification (INSERT, UPDATE/DELETE with WHERE)
    HIGH = 3      # schema changes (CREATE/ALTER) and data modification without WHERE
    CRITICAL = 4  # dangerous operations (DROP/TRUNCATE, etc.)

# Server configuration
class ServerConfig:
    """Server configuration"""
    HOST = os.getenv('HOST', '127.0.0.1')
    PORT = int(os.getenv('PORT', '3000'))

# Database configuration
class DatabaseConfig:
    """Database connection configuration"""
    HOST = os.getenv('MYSQL_HOST', 'localhost')
    USER = os.getenv('MYSQL_USER', 'root')
    PASSWORD = os.getenv('MYSQL_PASSWORD', '')
    DATABASE = os.getenv('MYSQL_DATABASE', '')
    PORT = int(os.getenv('MYSQL_PORT', '3306'))
    CONNECTION_TIMEOUT = int(os.getenv('DB_CONNECTION_TIMEOUT', '5'))
    AUTH_PLUGIN = os.getenv('DB_AUTH_PLUGIN', 'mysql_native_password')

    @staticmethod
    def get_config():
        """Return the database configuration as a dict"""
        return {
            'host': DatabaseConfig.HOST,
            'user': DatabaseConfig.USER,
            'password': DatabaseConfig.PASSWORD,
            'database': DatabaseConfig.DATABASE,
            'port': DatabaseConfig.PORT,
            'connection_timeout': DatabaseConfig.CONNECTION_TIMEOUT,
            'auth_plugin': DatabaseConfig.AUTH_PLUGIN
        }

# Database connection pool configuration
class ConnectionPoolConfig:
    """Database connection pool configuration"""
    # Minimum number of pooled connections
    MIN_SIZE = int(os.getenv('DB_POOL_MIN_SIZE', '5'))
    # Maximum number of pooled connections
    MAX_SIZE = int(os.getenv('DB_POOL_MAX_SIZE', '20'))
    # Connection recycle time (seconds)
    POOL_RECYCLE = int(os.getenv('DB_POOL_RECYCLE', '300'))
    # Maximum connection lifetime (seconds; 0 means unlimited)
    MAX_LIFETIME = int(os.getenv('DB_POOL_MAX_LIFETIME', '0'))
    # Timeout for acquiring a connection (seconds)
    ACQUIRE_TIMEOUT = float(os.getenv('DB_POOL_ACQUIRE_TIMEOUT', '10.0'))
    # Whether the connection pool is enabled
    ENABLED = os.getenv('DB_POOL_ENABLED', 'true').lower() in ('true', 'yes', '1')

    @staticmethod
    def get_config():
        """Return the connection pool configuration as a dict"""
        return {
            'minsize': ConnectionPoolConfig.MIN_SIZE,
            'maxsize': ConnectionPoolConfig.MAX_SIZE,
            'pool_recycle': ConnectionPoolConfig.POOL_RECYCLE,
            'max_lifetime': ConnectionPoolConfig.MAX_LIFETIME,
            'acquire_timeout': ConnectionPoolConfig.ACQUIRE_TIMEOUT,
            'enabled': ConnectionPoolConfig.ENABLED
        }

# Security configuration
class SecurityConfig:
    """Security-related configuration"""
    # Environment type
    ENV_TYPE_STR = os.getenv('ENV_TYPE', 'development').lower()
    try:
        ENV_TYPE = EnvironmentType(ENV_TYPE_STR)
    except ValueError:
        ENV_TYPE = EnvironmentType.DEVELOPMENT

    # Allowed risk levels
    ALLOWED_RISK_LEVELS_STR = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
    ALLOWED_RISK_LEVELS = set()
    for level_str in ALLOWED_RISK_LEVELS_STR.upper().split(','):
        level_str = level_str.strip()
        try:
            ALLOWED_RISK_LEVELS.add(SQLRiskLevel[level_str])
        except KeyError:
            pass

    # In production, if no risk levels are explicitly configured, only allow LOW-risk operations
    if ENV_TYPE == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
        ALLOWED_RISK_LEVELS = {SQLRiskLevel.LOW}

    # Maximum SQL length
    MAX_SQL_LENGTH = int(os.getenv('MAX_SQL_LENGTH', '1000'))

    # Sensitive information queries
    ALLOW_SENSITIVE_INFO = os.getenv('ALLOW_SENSITIVE_INFO', 'false').lower() in ('true', 'yes', '1')

    # Blocked patterns
    BLOCKED_PATTERNS_STR = os.getenv('BLOCKED_PATTERNS', '')
    BLOCKED_PATTERNS = [p.strip() for p in BLOCKED_PATTERNS_STR.split(',') if p.strip()]

    # Query checks
    ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1')

# SQL operation configuration
class SQLConfig:
    """SQL operation-related configuration"""
    # Base operation sets
    DDL_OPERATIONS = {
        'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
    }

    DML_OPERATIONS = {
        'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
    }

    # Metadata operation set
    METADATA_OPERATIONS = {
        'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
        'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
    }
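Because every value above is read from the environment at import time, the configuration can be checked with a small script. This is a minimal sketch and not part of the commit; it assumes the package is importable as src.config, matching the imports used in src/server.py, and the environment values shown are only illustrative.

import os

# Illustrative values; anything not set here falls back to the defaults above.
os.environ.setdefault('ENV_TYPE', 'production')
os.environ.setdefault('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
os.environ.setdefault('DB_POOL_MAX_SIZE', '10')

# Import only after the environment is populated: the classes call os.getenv at import time.
from src.config import SecurityConfig, ConnectionPoolConfig

print(SecurityConfig.ENV_TYPE)                                             # EnvironmentType.PRODUCTION
print(sorted(level.name for level in SecurityConfig.ALLOWED_RISK_LEVELS))  # ['LOW', 'MEDIUM']
print(ConnectionPoolConfig.get_config()['maxsize'])                        # 10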
src/db/mysql_operations.py

@@ -1,70 +1,309 @@

import os
import logging
import mysql.connector
from mysql.connector import Error
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
import aiomysql
import asyncio
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union
import threading
import weakref

from ..config import DatabaseConfig, SecurityConfig, SQLConfig, ConnectionPoolConfig
from ..security.sql_analyzer import SQLOperationType
from ..security.query_limiter import QueryLimiter
from ..security.interceptor import SQLInterceptor, SecurityException
from ..security.sql_parser import SQLParser

logger = logging.getLogger("mysql_server")

# Initialize the security components
sql_analyzer = SQLOperationType()
query_limiter = QueryLimiter()
sql_interceptor = SQLInterceptor(sql_analyzer)
# Global connection pools - kept in thread-local storage
_pools = threading.local()

# Periodically reclaim stale connection pools
_cleanup_interval = 300  # seconds; adjust as needed
_last_cleanup = 0

def _cleanup_unused_pools():
    """Reclaim invalid or already-closed connection pools and release their resources"""
    global _last_cleanup
    now = time.time()
    if now - _last_cleanup < _cleanup_interval:
        return
    _last_cleanup = now
    if hasattr(_pools, 'pools'):
        to_remove = []
        for loop_id, pool in list(_pools.pools.items()):
            # Check whether the owning event loop is still alive
            if pool.closed:
                to_remove.append(loop_id)
                continue
            # Try to locate the event loop object
            for loop in asyncio.all_tasks():
                if id(loop.get_loop()) == loop_id:
                    break
            else:
                # No matching event loop found; close the pool
                pool.close()
                to_remove.append(loop_id)
                logger.info(f"Pool has no owning event loop; closed it (event loop ID: {loop_id})")
        for loop_id in to_remove:
            del _pools.pools[loop_id]
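The registry above keys each pool by id(event_loop) inside threading.local(), so every thread and every event loop gets its own aiomysql pool. A standalone sketch of the same pattern, independent of this module and with placeholder objects standing in for real pools:

import asyncio
import threading

_registry = threading.local()

def _pool_for(loop):
    # One entry per event loop, keyed by id(loop); a real pool object would go here.
    if not hasattr(_registry, 'pools'):
        _registry.pools = {}
    return _registry.pools.setdefault(id(loop), object())

async def which_pool():
    return _pool_for(asyncio.get_running_loop())

# Two separate event loops end up with two distinct pool objects.
loop_a, loop_b = asyncio.new_event_loop(), asyncio.new_event_loop()
assert loop_a.run_until_complete(which_pool()) is not loop_b.run_until_complete(which_pool())
loop_a.close(); loop_b.close()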
def get_db_config():
    """Build the database configuration dynamically"""
    return {
        'host': os.getenv('MYSQL_HOST', 'localhost'),
        'user': os.getenv('MYSQL_USER', 'root'),
        'password': os.getenv('MYSQL_PASSWORD', ''),
        'database': os.getenv('MYSQL_DATABASE', ''),
        'port': int(os.getenv('MYSQL_PORT', '3306')),
        'connection_timeout': 5,
        'auth_plugin': 'mysql_native_password'  # specify the authentication plugin
    # Fetch the base configuration
    config = DatabaseConfig.get_config()

    # aiomysql uses different configuration key names, so map them
    aiomysql_config = {
        'host': config['host'],
        'user': config['user'],
        'password': config['password'],
        'db': config['database'],  # 'database' -> 'db'
        'port': config['port'],
        'connect_timeout': config.get('connection_timeout', 5),  # 'connection_timeout' -> 'connect_timeout'
        # auth_plugin is not directly supported by aiomysql, so the parameter is dropped
    }

    return aiomysql_config
@contextmanager
def get_db_connection():
# Custom exception classes for finer-grained error handling
class MySQLConnectionError(Exception):
    """Base class for database connection errors"""
    pass

class MySQLAuthError(MySQLConnectionError):
    """Authentication error"""
    pass

class MySQLDatabaseNotFoundError(MySQLConnectionError):
    """Database-does-not-exist error"""
    pass

class MySQLServerError(MySQLConnectionError):
    """Server connection error"""
    pass

class MySQLAuthPluginError(MySQLConnectionError):
    """Authentication plugin error"""
    pass
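The refined exception hierarchy lets callers react differently to authentication, missing-database, and connectivity failures instead of catching one generic error. A hedged usage sketch; the function names follow this module, but the handling branches are purely illustrative:

async def connect_with_diagnostics():
    try:
        return await init_db_pool()
    except MySQLAuthError:
        # Credentials problem: rotate or re-prompt, do not retry blindly.
        raise
    except MySQLDatabaseNotFoundError:
        # MYSQL_DATABASE points at a schema that does not exist yet.
        raise
    except MySQLServerError:
        # Server unreachable: a single retry with a short backoff is reasonable here.
        await asyncio.sleep(2)
        return await init_db_pool()
    except MySQLConnectionError as exc:
        # Catch-all for everything else in the hierarchy.
        raise RuntimeError(f"database unavailable: {exc}") from exc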
async def init_db_pool(min_size: Optional[int] = None, max_size: Optional[int] = None, require_database: bool = True):
    """
    Context manager that creates a database connection
    Initialize the database connection pool

    Args:
        min_size: minimum pool size (optional; defaults to the configured value)
        max_size: maximum pool size (optional; defaults to the configured value)
        require_database: whether a database name is required

    Returns:
        The connection pool object

    Raises:
        MySQLConnectionError: if pool initialization fails
    """
    try:
        # Fetch the database configuration
        db_config = get_db_config()

        # Check whether a database name is required
        if require_database and not db_config.get('db'):
            raise MySQLDatabaseNotFoundError("Database name not set; check the MYSQL_DATABASE environment variable")

        # If no database is required and db is empty, drop the db parameter
        if not require_database and not db_config.get('db'):
            db_config.pop('db', None)

        # Get the current event loop
        current_loop = asyncio.get_event_loop()
        loop_id = id(current_loop)

        # Fetch the connection pool configuration
        pool_config = ConnectionPoolConfig.get_config()

        # Use the passed-in arguments or fall back to the configured values
        min_size = min_size if min_size is not None else pool_config['minsize']
        max_size = max_size if max_size is not None else pool_config['maxsize']
        pool_recycle = pool_config['pool_recycle']

        # Check whether pooling is enabled
        if not pool_config['enabled']:
            logger.warning("Connection pooling is disabled; using a direct connection")
            # Create a single-connection pool
            min_size = 1
            max_size = 1

        # Create the connection pool
        logger.info(f"Initializing connection pool: min size={min_size}, max size={max_size}, recycle time={pool_recycle}s")
        pool = await aiomysql.create_pool(
            minsize=min_size,
            maxsize=max_size,
            pool_recycle=pool_recycle,
            echo=False,  # do not log executed SQL here; our own logging handles that
            loop=current_loop,  # bind explicitly to the current event loop
            **db_config
        )

        # Store the pool in thread-local storage, keyed by the event loop ID
        if not hasattr(_pools, 'pools'):
            _pools.pools = {}
        _pools.pools[loop_id] = pool

        # Register automatic cleanup for when the event loop goes away
        def _finalizer(p=pool, lid=loop_id):
            if not p.closed:
                p.close()
                logger.info(f"Connection pool closed automatically on event loop shutdown (event loop ID: {lid})")
        try:
            weakref.finalize(current_loop, _finalizer)
        except Exception as e:
            logger.warning(f"Failed to register the event-loop shutdown callback: {e}")

        logger.info(f"MySQL connection pool initialized; min size: {min_size}, max size: {max_size}, event loop ID: {loop_id}")
        return pool
    except aiomysql.Error as err:
        error_msg = str(err)
        logger.error(f"Database connection pool initialization failed: {error_msg}")

        # Refine the error type
        if "Access denied" in error_msg:
            raise MySQLAuthError("Access denied; check the username and password")
        elif "Unknown database" in error_msg:
            raise MySQLDatabaseNotFoundError(f"Database '{db_config.get('db', '')}' does not exist")
        elif "Can't connect" in error_msg or "Connection refused" in error_msg:
            raise MySQLServerError("Cannot connect to the MySQL server; check that the service is running")
        elif "Authentication plugin" in error_msg:
            raise MySQLAuthPluginError(f"Authentication plugin problem: {error_msg}; try switching the user's authentication method to mysql_native_password")
        else:
            raise MySQLConnectionError(f"Database connection failed: {error_msg}")
    except Exception as e:
        logger.error(f"Unexpected error during pool initialization: {str(e)}")
        raise MySQLConnectionError(f"Pool initialization failed: {str(e)}")
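A typical startup call, as a sketch. It assumes the module path src.db.mysql_operations used elsewhere in this commit; the explicit pool sizes only illustrate overriding the env-driven defaults.

import asyncio
from src.db.mysql_operations import init_db_pool

async def startup():
    # Defaults come from ConnectionPoolConfig; explicit sizes are optional overrides.
    pool = await init_db_pool(min_size=2, max_size=8)
    print(f"pool ready: free={pool.freesize}, size={pool.size}")

asyncio.run(startup())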
def get_pool_for_current_loop():
    """Return the connection pool that belongs to the current event loop"""
    _cleanup_unused_pools()  # attempt a cleanup on every lookup
    try:
        # Get the current event loop ID
        current_loop = asyncio.get_event_loop()
        loop_id = id(current_loop)

        # Check whether a pool exists for this loop
        if hasattr(_pools, 'pools') and loop_id in _pools.pools:
            pool = _pools.pools[loop_id]
            # Check whether the pool has been closed
            if pool.closed:
                logger.debug(f"Pool already closed; it will be recreated (event loop ID: {loop_id})")
                return None
            return pool
        return None
    except Exception as e:
        logger.error(f"Failed to get the connection pool for the current event loop: {str(e)}")
        return None
@asynccontextmanager
async def get_db_connection(require_database: bool = True):
    """
    Create a database connection (context manager)
    Async context manager that obtains a database connection from the pool

    Args:
        require_database: whether a database must be specified. Set it to False to run statements
            such as SHOW DATABASES that do not need a specific database.

    Yields:
        mysql.connector.connection.MySQLConnection: the database connection object
        aiomysql.Connection: the database connection object
    """
    connection = None
    # Get the connection pool for the current event loop
    pool = get_pool_for_current_loop()

    # If there is no pool yet, initialize one
    if pool is None:
        pool = await init_db_pool(require_database=require_database)

    try:
        db_config = get_db_config()
        if not db_config['database']:
            raise ValueError("Database name not set; check the MYSQL_DATABASE environment variable")

        connection = mysql.connector.connect(**db_config)
        yield connection
    except mysql.connector.Error as err:
        # Acquire a connection from the pool
        async with pool.acquire() as connection:
            yield connection
    except aiomysql.Error as err:
        error_msg = str(err)
        logger.error(f"Database connection failed: {error_msg}")
        logger.error(f"Failed to obtain a database connection: {error_msg}")

        if "Access denied" in error_msg:
            raise ValueError("Access denied; check the username and password")
            raise MySQLAuthError("Access denied; check the username and password")
        elif "Unknown database" in error_msg:
            db_config = get_db_config()
            raise ValueError(f"Database '{db_config['database']}' does not exist")
            raise MySQLDatabaseNotFoundError(f"Database '{db_config.get('db', '')}' does not exist")
        elif "Can't connect" in error_msg or "Connection refused" in error_msg:
            raise ConnectionError("Cannot connect to the MySQL server; check that the service is running")
            raise MySQLServerError("Cannot connect to the MySQL server; check that the service is running")
        elif "Authentication plugin" in error_msg:
            raise ValueError(f"Authentication plugin problem: {error_msg}; try switching the user's authentication method to mysql_native_password")
            raise MySQLAuthPluginError(f"Authentication plugin problem: {error_msg}; try switching the user's authentication method to mysql_native_password")
        else:
            raise ConnectionError(f"Database connection failed: {error_msg}")
    finally:
        if connection and connection.is_connected():
            connection.close()
            logger.debug("Database connection closed")
            raise MySQLConnectionError(f"Database connection failed: {error_msg}")
    except Exception as e:
        logger.error(f"Unexpected error while obtaining a database connection: {str(e)}")
        raise MySQLConnectionError(f"Failed to obtain a database connection: {str(e)}")
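Acquiring a connection then follows the usual async-context-manager shape. A minimal sketch; the statement is illustrative and require_database=False is only needed for server-level queries:

import asyncio
from src.db.mysql_operations import get_db_connection, execute_query

async def list_databases():
    # require_database=False allows server-level statements before a schema is chosen.
    async with get_db_connection(require_database=False) as conn:
        return await execute_query(conn, "SHOW DATABASES")

print(asyncio.run(list_databases()))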
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
async def close_all_pools():
    """Close all connection pools"""
    if hasattr(_pools, 'pools'):
        for loop_id, pool in list(_pools.pools.items()):
            if not pool.closed:
                pool.close()
                await pool.wait_closed()
                logger.info(f"Connection pool closed (event loop ID: {loop_id})")
        _pools.pools = {}
@asynccontextmanager
async def transaction(connection):
    """
    Transaction context manager

    Usage example:
        async with get_db_connection() as conn:
            async with transaction(conn):
                await execute_query(conn, "INSERT INTO...")
                await execute_query(conn, "UPDATE...")

    Args:
        connection: the database connection

    Yields:
        connection: the database connection inside the transaction
    """
    try:
        # Begin the transaction
        await connection.begin()
        logger.debug("Transaction started")
        yield connection
        # Commit the transaction
        await connection.commit()
        logger.debug("Transaction committed")
    except Exception as e:
        # Roll back the transaction
        await connection.rollback()
        logger.error(f"Transaction failed and was rolled back: {str(e)}")
        raise
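The docstring above already shows the intended call shape; the value of the wrapper is that any exception raised inside the block triggers connection.rollback() before the error propagates. A short sketch with illustrative statements and table names:

async def transfer(conn, amount: int):
    async with transaction(conn):
        await execute_query(conn, "UPDATE accounts SET balance = balance - %s WHERE id = 1", (amount,))
        # Any exception raised inside this block makes transaction() roll back and re-raise.
        await execute_query(conn, "UPDATE accounts SET balance = balance + %s WHERE id = 2", (amount,))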
def normalize_result(result_rows):
    """
    Convert DictRow objects into plain dictionaries

    Args:
        result_rows: the list of result rows

    Returns:
        A list of plain dictionaries
    """
    if not result_rows:
        return []

    return [dict(row) for row in result_rows]
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None,
                        batch_size: int = 1000, stream_results: bool = False) -> List[Dict[str, Any]]:
    """
    Execute a query on the given database connection

@@ -72,6 +311,8 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
        connection: the database connection
        query: the SQL statement
        params: query parameters (optional)
        batch_size: how many records to fetch from the cursor per batch (only used when stream_results=True)
        stream_results: whether to stream large result sets

    Returns:
        The list of result rows; for modifying statements, the number of affected rows
@@ -81,78 +322,211 @@ async def execute_query(connection, query: str, params: Optional[Dict[str, Any]]
        ValueError: if query execution fails
    """
    cursor = None
    operation = None  # initialize the operation type
    parsed_sql = None  # initialize the SQL parse result
    start_time = time.time()  # record the query start time

    try:
        # Security check
        if not await sql_interceptor.check_operation(query):
            raise SecurityException("Operation rejected by the security mechanism")

        cursor = connection.cursor(dictionary=True)
        # Create an async cursor that returns dictionary rows
        cursor = await connection.cursor(aiomysql.DictCursor)

        # Execute the query
        # Execute the query - asynchronously
        if params:
            cursor.execute(query, params)
            # Check the parameter type and convert it to a format aiomysql accepts
            if isinstance(params, dict):
                # Build a query using the %(key)s placeholder format
                await cursor.execute(query, params)
            else:
                await cursor.execute(query, params)
        else:
            cursor.execute(query)
            await cursor.execute(query)

        # Determine the operation type
        operation = query.strip().split()[0].upper()
        # Parse the SQL statement to determine the operation type
        parsed_sql = SQLParser.parse_query(query)
        operation = parsed_sql['operation_type']

        # For modifying statements, commit and return the number of affected rows
        if operation in {'UPDATE', 'DELETE', 'INSERT'}:
        if parsed_sql['category'] == 'DML' and operation in {'UPDATE', 'DELETE', 'INSERT'}:
            affected_rows = cursor.rowcount
            # Commit so that the changes are persisted
            connection.commit()
            await connection.commit()
            logger.debug(f"Modifying operation {operation} affected {affected_rows} rows")

            # Record the query execution time
            execution_time = time.time() - start_time
            _log_query_performance(query, execution_time, operation)

            return [{'affected_rows': affected_rows}]

        # Handle metadata queries
        if operation in sql_analyzer.metadata_operations:
            # Fetch the result set
            results = cursor.fetchall()
        if parsed_sql['category'] == 'METADATA':
            # Metadata queries are usually small, so fetch everything at once
            results = await cursor.fetchall()

            # With no results, return an empty list but include meta information
            if not results:
                logger.debug(f"Metadata query {operation} returned no results")
                # Record the query execution time
                execution_time = time.time() - start_time
                _log_query_performance(query, execution_time, operation)
                return [{'metadata_operation': operation, 'result_count': 0}]

            # Improve the result format - enrich metadata results with extra fields
            metadata_results = []
            for row in results:
                # Special handling for certain metadata queries
                if operation == 'SHOW' and 'Table' in row:
                    # Enrich SHOW TABLES results
                    row['table_name'] = row['Table']
                elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row:
                    # Enrich DESC/DESCRIBE table-structure results
                    row['column_name'] = row['Field']
                    row['data_type'] = row['Type']
                # Convert the row into a plain dict rather than DictCursor's special object
                row_dict = dict(row)

                metadata_results.append(row)
                # Special handling for certain metadata queries
                if operation == 'SHOW' and 'Table' in row_dict:
                    # Enrich SHOW TABLES results
                    row_dict['table_name'] = row_dict['Table']
                elif operation in {'DESC', 'DESCRIBE'} and 'Field' in row_dict:
                    # Enrich DESC/DESCRIBE table-structure results
                    row_dict['column_name'] = row_dict['Field']
                    row_dict['data_type'] = row_dict['Type']

                metadata_results.append(row_dict)

            logger.debug(f"Metadata query {operation} returned {len(metadata_results)} results")

            # Record the query execution time
            execution_time = time.time() - start_time
            _log_query_performance(query, execution_time, operation)

            return metadata_results

        # For read queries, return the result set
        results = cursor.fetchall()
        logger.debug(f"Query returned {len(results)} results")
        return results
        # For ordinary queries, decide how to fetch results based on stream_results
        if stream_results:
            # Stream large result sets - fetch in batches
            all_results = []
            total_fetched = 0

            # Fetch the results batch by batch
            while True:
                batch = await cursor.fetchmany(batch_size)
                if not batch:
                    break

                # Convert DictRow objects into plain dicts with the helper
                dict_batch = normalize_result(batch)
                all_results.extend(dict_batch)

                total_fetched += len(batch)
                logger.debug(f"Fetched {total_fetched} records so far")

                # Check whether any results remain
                if len(batch) < batch_size:
                    break

            logger.debug(f"Streaming query returned {len(all_results)} results in total")

            # Record the query execution time
            execution_time = time.time() - start_time
            _log_query_performance(query, execution_time, operation)

            return all_results
        else:
            # Traditional mode - fetch everything at once
            results = await cursor.fetchall()

            # Convert DictRow objects into plain dicts with the helper
            dict_results = normalize_result(results)

            logger.debug(f"Query returned {len(dict_results)} results")

            # Record the query execution time
            execution_time = time.time() - start_time
            _log_query_performance(query, execution_time, operation)

            return dict_results

    except SecurityException as security_err:
        logger.error(f"Security check failed: {str(security_err)}")
        raise
    except mysql.connector.Error as query_err:
    except aiomysql.Error as query_err:
        # Roll back if an error occurred
        if operation and operation in {'UPDATE', 'DELETE', 'INSERT'}:  # make sure operation is defined
        if parsed_sql and parsed_sql['operation_type'] in {'UPDATE', 'DELETE', 'INSERT'}:
            try:
                connection.rollback()
                await connection.rollback()
                logger.debug("Transaction rolled back")
            except:
                pass
            except Exception as rollback_err:
                logger.error(f"Failed to roll back the transaction: {str(rollback_err)}")
        logger.error(f"Query execution failed: {str(query_err)}")
        raise ValueError(f"Query execution failed: {str(query_err)}")
    finally:
        # Make sure the cursor is closed properly
        if cursor:
            cursor.close()
            logger.debug("Database cursor closed")
            await cursor.close()
            logger.debug("Database cursor closed")
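For large result sets the new stream_results path pulls rows with fetchmany(batch_size) instead of one fetchall() call and normalizes each batch to plain dicts. A usage sketch; the table name and sizes are illustrative:

import asyncio
from src.db.mysql_operations import get_db_connection, execute_query

async def export_rows():
    async with get_db_connection() as conn:
        # Rows are fetched in batches of 500 and converted to plain dicts.
        rows = await execute_query(conn, "SELECT * FROM big_table", batch_size=500, stream_results=True)
        print(f"fetched {len(rows)} rows")

asyncio.run(export_rows())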
def _log_query_performance(query: str, execution_time: float, operation_type: str = ""):
    """
    Log query performance

    Args:
        query: the SQL statement
        execution_time: execution time in seconds
        operation_type: the operation type
    """
    # Truncate long queries to keep the log small
    truncated_query = query[:150] + '...' if len(query) > 150 else query

    # Choose the log level based on the execution time
    if execution_time >= 1.0:  # queries over 1 second are logged as warnings
        logger.warning(f"Slow query [{operation_type}]: {truncated_query} took {execution_time:.4f}s")
    elif execution_time >= 0.5:  # queries over 0.5 seconds are logged as notices
        logger.info(f"Fairly slow query [{operation_type}]: {truncated_query} took {execution_time:.4f}s")
    else:
        logger.debug(f"Query [{operation_type}] took {execution_time:.4f}s")
async def execute_transaction_queries(connection, queries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Execute multiple queries inside a single transaction

    Args:
        connection: the database connection
        queries: a list of queries, each a dict containing 'query' and an optional 'params'

    Returns:
        The results of all queries

    Raises:
        Exception: if any query fails, the whole transaction is rolled back
    """
    results = []

    async with transaction(connection):
        for query_item in queries:
            query = query_item['query']
            params = query_item.get('params')

            # Execute the individual query
            result = await execute_query(connection, query, params)
            results.append(result)

    return results
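A sketch of the multi-statement form: each item carries its own optional params, and a failure in any item rolls the whole batch back. The statement text and table names are illustrative.

from src.db.mysql_operations import get_db_connection, execute_transaction_queries

queries = [
    {"query": "INSERT INTO audit_log (event) VALUES (%s)", "params": ("login",)},
    {"query": "UPDATE users SET last_seen = NOW() WHERE id = %s", "params": (42,)},
]

async def run_batch():
    async with get_db_connection() as conn:
        return await execute_transaction_queries(conn, queries)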
async def get_current_database() -> str:
    """
    Return the name of the currently connected database

    Returns:
        The current database name, or an empty string if none is set
    """
    async with get_db_connection(require_database=False) as connection:
        try:
            cursor = await connection.cursor(aiomysql.DictCursor)
            await cursor.execute("SELECT DATABASE() as db")
            result = await cursor.fetchone()
            await cursor.close()

            if result and 'db' in result:
                return result['db'] or ""
            return ""
        except Exception as e:
            logger.error(f"Failed to get the current database name: {str(e)}")
            return ""
src/security/interceptor.py

@@ -1,8 +1,9 @@

import logging
import os
from typing import List, Dict

from ..config import SecurityConfig, SQLConfig
from .sql_analyzer import SQLOperationType, SQLRiskLevel
from .sql_parser import SQLParser

logger = logging.getLogger(__name__)

@@ -15,8 +16,8 @@ class SQLInterceptor:

    def __init__(self, analyzer: SQLOperationType):
        self.analyzer = analyzer
        # Maximum SQL length limit (default: 1000 characters)
        self.max_sql_length = 1000
        # Maximum SQL length limit
        self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH

    async def check_operation(self, sql_query: str) -> bool:
        """
@@ -40,19 +41,16 @@ class SQLInterceptor:
        if len(sql_query) > self.max_sql_length:
            raise SecurityException(f"SQL statement length ({len(sql_query)}) exceeds the limit ({self.max_sql_length})")

        # Parse the SQL with SQLParser
        parsed_sql = SQLParser.parse_query(sql_query)

        # Check whether the SQL is valid
        sql_parts = sql_query.strip().split()
        if not sql_parts:
        if not parsed_sql['is_valid']:
            raise SecurityException("Invalid SQL statement format")

        operation = sql_parts[0].upper()
        operation = parsed_sql['operation_type']
        # Updated list of supported operation types, including metadata operations
        supported_operations = {
            'SELECT', 'INSERT', 'UPDATE', 'DELETE',
            'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE',
            'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
            'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
        }
        supported_operations = SQLConfig.DDL_OPERATIONS | SQLConfig.DML_OPERATIONS | SQLConfig.METADATA_OPERATIONS

        if operation not in supported_operations:
            raise SecurityException(f"Unsupported SQL operation: {operation}")
@@ -74,13 +72,11 @@ class SQLInterceptor:
        )

        # Determine the operation category (DDL, DML, or metadata)
        operation_category = "metadata operation" if operation in self.analyzer.metadata_operations else (
            "DDL operation" if operation in self.analyzer.ddl_operations else "DML operation"
        )
        operation_category = parsed_sql['category']

        # Log the details
        logger.info(
            f"SQL {operation_category} check passed - "
            f"SQL {operation_category} operation check passed - "
            f"operation: {risk_analysis['operation']}, "
            f"risk level: {risk_analysis['risk_level'].name}, "
            f"affected tables: {', '.join(risk_analysis['affected_tables'])}"
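Putting the pieces together, check_operation either returns True or raises SecurityException with the reason. A hedged sketch of a standalone check; the import paths mirror the src.security package used elsewhere in this commit, the statements are illustrative, and the expected outcome assumes the default LOW/MEDIUM risk allowance, under which a DROP should be rejected.

import asyncio
from src.security.sql_analyzer import SQLOperationType
from src.security.interceptor import SQLInterceptor, SecurityException

async def demo():
    interceptor = SQLInterceptor(SQLOperationType())
    assert await interceptor.check_operation("SELECT id FROM users WHERE id = 1")
    try:
        await interceptor.check_operation("DROP TABLE users")
    except SecurityException as exc:
        print(f"blocked: {exc}")  # CRITICAL-risk statements are rejected unless explicitly allowed

asyncio.run(demo())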
src/security/query_limiter.py

@@ -1,16 +1,17 @@

import os
import logging
from typing import Tuple

from ..config import SecurityConfig, SQLConfig
from .sql_parser import SQLParser

logger = logging.getLogger(__name__)

class QueryLimiter:
    """Query safety checker"""

    def __init__(self):
        # Parse the enabled flag (enabled by default)
        enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true')
        self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'}
        # Read the enabled flag from the configuration
        self.enable_check = SecurityConfig.ENABLE_QUERY_CHECK

    def check_query(self, sql_query: str) -> Tuple[bool, str]:
        """
@@ -25,44 +26,14 @@ class QueryLimiter:
        if not self.enable_check:
            return True, ""

        sql_query = sql_query.strip().upper()
        operation_type = self._get_operation_type(sql_query)
        # Parse the SQL with SQLParser
        parsed_sql = SQLParser.parse_query(sql_query)
        operation_type = parsed_sql['operation_type']

        # Check for UPDATE/DELETE statements without a WHERE clause
        if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query:
        if operation_type in {'UPDATE', 'DELETE'} and not parsed_sql['has_where']:
            error_msg = f"{operation_type} statements must include a WHERE clause"
            logger.warning(f"Query restricted: {error_msg}")
            return False, error_msg

        return True, ""

    def _get_operation_type(self, sql_query: str) -> str:
        """Return the SQL operation type"""
        if not sql_query:
            return ""
        words = sql_query.split()
        if not words:
            return ""
        return words[0].upper()

    def _parse_int_env(self, env_name: str, default: int) -> int:
        """Parse an integer environment variable"""
        try:
            return int(os.getenv(env_name, str(default)))
        except (ValueError, TypeError):
            return default

    def update_limits(self, new_limits: dict):
        """
        Update the limit thresholds

        Args:
            new_limits: dictionary of new limit values
        """
        for operation, limit in new_limits.items():
            if operation in self.max_limits:
                try:
                    self.max_limits[operation] = int(limit)
                    logger.info(f"Updated the limit for {operation} operations to: {limit}")
                except (ValueError, TypeError):
                    logger.warning(f"Invalid limit value: {operation}={limit}")
        return True, ""
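A sketch of the limiter on its own: with checks enabled it rejects UPDATE/DELETE statements that the parser reports as having no WHERE clause. The statements are illustrative.

from src.security.query_limiter import QueryLimiter

limiter = QueryLimiter()

ok, reason = limiter.check_query("UPDATE users SET active = 0")
print(ok, reason)   # False, "UPDATE statements must include a WHERE clause"

ok, reason = limiter.check_query("DELETE FROM sessions WHERE expired = 1")
print(ok, reason)   # True, ""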
@ -1,81 +1,31 @@
|
||||
import re
|
||||
import os
|
||||
from enum import IntEnum, Enum
|
||||
import logging
|
||||
from typing import Set, List
|
||||
from typing import List, Set
|
||||
|
||||
from ..config import SQLRiskLevel, EnvironmentType, SecurityConfig, SQLConfig
|
||||
from .sql_parser import SQLParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLRiskLevel(IntEnum):
|
||||
"""SQL操作风险等级"""
|
||||
LOW = 1 # 查询操作(SELECT)
|
||||
MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE)
|
||||
HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改
|
||||
CRITICAL = 4 # 危险操作(DROP/TRUNCATE等)
|
||||
|
||||
class EnvironmentType(Enum):
|
||||
"""环境类型"""
|
||||
DEVELOPMENT = 'development'
|
||||
PRODUCTION = 'production'
|
||||
|
||||
class SQLOperationType:
|
||||
"""SQL操作类型分析器"""
|
||||
|
||||
def __init__(self):
|
||||
# 环境类型处理
|
||||
env_type_str = os.getenv('ENV_TYPE', 'development').lower()
|
||||
try:
|
||||
self.env_type = EnvironmentType(env_type_str)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development")
|
||||
self.env_type = EnvironmentType.DEVELOPMENT
|
||||
# 环境类型从配置读取
|
||||
self.env_type = SecurityConfig.ENV_TYPE
|
||||
|
||||
# 基础操作集合
|
||||
self.ddl_operations = {
|
||||
'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME'
|
||||
}
|
||||
self.dml_operations = {
|
||||
'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE'
|
||||
}
|
||||
# 操作类型集合从配置读取
|
||||
self.ddl_operations = SQLConfig.DDL_OPERATIONS
|
||||
self.dml_operations = SQLConfig.DML_OPERATIONS
|
||||
self.metadata_operations = SQLConfig.METADATA_OPERATIONS
|
||||
|
||||
# 添加元数据操作集合
|
||||
self.metadata_operations = {
|
||||
'SHOW', 'DESC', 'DESCRIBE', 'EXPLAIN', 'HELP',
|
||||
'ANALYZE', 'CHECK', 'CHECKSUM', 'OPTIMIZE'
|
||||
}
|
||||
|
||||
# 风险等级配置
|
||||
self.allowed_risk_levels = self._parse_risk_levels()
|
||||
self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS')
|
||||
|
||||
# 生产环境特殊处理:如果没有明确配置风险等级,则只允许LOW风险操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'):
|
||||
self.allowed_risk_levels = {SQLRiskLevel.LOW}
|
||||
# 风险等级配置从配置读取
|
||||
self.allowed_risk_levels = SecurityConfig.ALLOWED_RISK_LEVELS
|
||||
self.blocked_patterns = SecurityConfig.BLOCKED_PATTERNS
|
||||
|
||||
logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}")
|
||||
logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}")
|
||||
|
||||
def _parse_risk_levels(self) -> Set[SQLRiskLevel]:
|
||||
"""解析允许的风险等级"""
|
||||
allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM')
|
||||
allowed_levels = set()
|
||||
|
||||
logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'")
|
||||
|
||||
for level_str in allowed_levels_str.upper().split(','):
|
||||
level_str = level_str.strip()
|
||||
try:
|
||||
allowed_levels.add(SQLRiskLevel[level_str])
|
||||
except KeyError:
|
||||
logger.warning(f"未知的风险等级配置: {level_str}")
|
||||
|
||||
return allowed_levels
|
||||
|
||||
def _parse_blocked_patterns(self, env_var: str) -> List[str]:
|
||||
"""解析禁止的操作模式"""
|
||||
patterns = os.getenv(env_var, '').split(',')
|
||||
return [p.strip() for p in patterns if p.strip()]
|
||||
|
||||
def analyze_risk(self, sql_query: str) -> dict:
|
||||
"""
|
||||
分析SQL查询的风险级别和影响范围
|
||||
@ -105,54 +55,76 @@ class SQLOperationType:
|
||||
'is_allowed': False
|
||||
}
|
||||
|
||||
operation = sql_query.split()[0].upper()
|
||||
# 使用SQLParser解析SQL
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
operation = parsed_sql['operation_type']
|
||||
|
||||
# 基本风险分析
|
||||
risk_analysis = {
|
||||
'operation': operation,
|
||||
'operation_type': 'DDL' if operation in self.ddl_operations else 'DML',
|
||||
'operation_type': parsed_sql['category'],
|
||||
'is_dangerous': self._check_dangerous_patterns(sql_query),
|
||||
'affected_tables': self._get_affected_tables(sql_query),
|
||||
'estimated_impact': self._estimate_impact(sql_query)
|
||||
'affected_tables': parsed_sql['tables'],
|
||||
'estimated_impact': self._estimate_impact(sql_query, parsed_sql)
|
||||
}
|
||||
|
||||
# 计算风险等级
|
||||
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'])
|
||||
risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous'], parsed_sql['has_where'])
|
||||
risk_analysis['risk_level'] = risk_level
|
||||
risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels
|
||||
|
||||
return risk_analysis
|
||||
|
||||
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel:
|
||||
def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool, has_where: bool) -> SQLRiskLevel:
|
||||
"""
|
||||
计算操作风险等级
|
||||
|
||||
规则:
|
||||
1. 危险操作(匹配危险模式)=> CRITICAL
|
||||
2. DDL操作:
|
||||
2. 生产环境非SELECT操作 => CRITICAL
|
||||
3. DDL操作:
|
||||
- CREATE/ALTER => HIGH
|
||||
- DROP/TRUNCATE => CRITICAL
|
||||
3. DML操作:
|
||||
4. DML操作:
|
||||
- SELECT => LOW
|
||||
- INSERT => MEDIUM
|
||||
- UPDATE/DELETE(有WHERE)=> MEDIUM
|
||||
- UPDATE(无WHERE)=> HIGH
|
||||
- DELETE(无WHERE)=> CRITICAL
|
||||
4. 元数据操作:
|
||||
5. 元数据操作:
|
||||
- SHOW/DESC/DESCRIBE等 => LOW
|
||||
6. 多语句SQL通常被认为是高风险的
|
||||
"""
|
||||
# 解析SQL获取额外信息
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
|
||||
# 危险操作
|
||||
if is_dangerous:
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 生产环境特别规则
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中只允许SELECT和元数据操作
|
||||
if operation != 'SELECT' and parsed_sql['category'] != 'METADATA':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# 生产环境中的多语句SQL视为高风险
|
||||
if parsed_sql.get('multi_statement', False):
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
# 多语句SQL在任何环境中都是更高风险的
|
||||
if parsed_sql.get('multi_statement', False):
|
||||
# 至少中等风险,如果包含DDL则为高风险或严重风险
|
||||
if parsed_sql['category'] == 'DDL':
|
||||
return SQLRiskLevel.HIGH
|
||||
elif parsed_sql['category'] == 'DML' and operation not in {'SELECT'}:
|
||||
return SQLRiskLevel.HIGH
|
||||
return SQLRiskLevel.MEDIUM
|
||||
|
||||
# 元数据操作
|
||||
if operation in self.metadata_operations:
|
||||
return SQLRiskLevel.LOW # 元数据查询视为低风险操作
|
||||
|
||||
# 生产环境中非SELECT操作
|
||||
if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT':
|
||||
return SQLRiskLevel.CRITICAL
|
||||
|
||||
# DDL操作
|
||||
if operation in self.ddl_operations:
|
||||
if operation in {'DROP', 'TRUNCATE'}:
|
||||
@ -161,61 +133,58 @@ class SQLOperationType:
|
||||
|
||||
# DML操作
|
||||
if operation == 'SELECT':
|
||||
# 对于不带LIMIT的大型SELECT, 风险可能提高
|
||||
if not parsed_sql['has_limit'] and self.env_type == EnvironmentType.PRODUCTION:
|
||||
return SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.LOW
|
||||
elif operation == 'INSERT':
|
||||
return SQLRiskLevel.MEDIUM
|
||||
elif operation == 'UPDATE':
|
||||
return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.HIGH if not has_where else SQLRiskLevel.MEDIUM
|
||||
elif operation == 'DELETE':
|
||||
# 无WHERE条件的DELETE操作视为CRITICAL风险
|
||||
return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM
|
||||
return SQLRiskLevel.CRITICAL if not has_where else SQLRiskLevel.MEDIUM
|
||||
|
||||
# 默认情况
|
||||
return SQLRiskLevel.HIGH
|
||||
|
||||
def _check_dangerous_patterns(self, sql_query: str) -> bool:
|
||||
"""检查是否匹配危险操作模式"""
|
||||
sql_upper = sql_query.upper()
|
||||
|
||||
# 生产环境额外的安全检查
|
||||
if self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中禁止所有非SELECT操作
|
||||
if sql_upper.split()[0] != 'SELECT':
|
||||
return True
|
||||
# 解析SQL以获取更多信息
|
||||
parsed_sql = SQLParser.parse_query(sql_query)
|
||||
|
||||
# 检查是否为多语句SQL - 大多数情况下使用多语句SQL可能是危险的
|
||||
if parsed_sql.get('multi_statement', False) and self.env_type == EnvironmentType.PRODUCTION:
|
||||
# 生产环境中的多语句SQL视为危险
|
||||
return True
|
||||
|
||||
# 对敏感关键字的检查
|
||||
for pattern in self.blocked_patterns:
|
||||
if re.search(pattern, sql_upper, re.IGNORECASE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_affected_tables(self, sql_query: str) -> list:
|
||||
"""获取受影响的表名列表"""
|
||||
words = sql_query.upper().split()
|
||||
tables = []
|
||||
|
||||
for i, word in enumerate(words):
|
||||
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||
if i + 1 < len(words):
|
||||
table = words[i + 1].strip('`;')
|
||||
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||
tables.append(table)
|
||||
|
||||
return list(set(tables))
|
||||
|
||||
def _estimate_impact(self, sql_query: str) -> dict:
|
||||
def _estimate_impact(self, sql_query: str, parsed_sql: dict) -> dict:
|
||||
"""
|
||||
估算查询影响范围
|
||||
|
||||
Args:
|
||||
sql_query: 原始SQL查询
|
||||
parsed_sql: 解析后的SQL信息
|
||||
|
||||
Returns:
|
||||
dict: 包含预估影响的字典
|
||||
"""
|
||||
operation = sql_query.split()[0].upper()
|
||||
operation = parsed_sql['operation_type']
|
||||
|
||||
impact = {
|
||||
'operation': operation,
|
||||
'estimated_rows': 0,
|
||||
'needs_where': operation in {'UPDATE', 'DELETE'},
|
||||
'has_where': 'WHERE' in sql_query.upper()
|
||||
'has_where': parsed_sql['has_where']
|
||||
}
|
||||
|
||||
# 根据环境类型调整估算
|
||||
|
||||
358
src/security/sql_parser.py
Normal file
358
src/security/sql_parser.py
Normal file
@ -0,0 +1,358 @@
|
||||
import sqlparse
|
||||
import re
|
||||
import logging
|
||||
from typing import List, Set, Tuple, Optional, Dict
|
||||
|
||||
from ..config import SQLConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SQLParser:
|
||||
"""
|
||||
SQL解析器 - 使用sqlparse库提供更精确的SQL解析功能
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def parse_query(sql_query: str) -> Dict:
|
||||
"""
|
||||
解析SQL查询,返回解析结果
|
||||
|
||||
Args:
|
||||
sql_query: SQL查询语句
|
||||
|
||||
Returns:
|
||||
Dict: 包含解析结果的字典
|
||||
"""
|
||||
if not sql_query or not sql_query.strip():
|
||||
return {
|
||||
'operation_type': '',
|
||||
'tables': [],
|
||||
'has_where': False,
|
||||
'has_limit': False,
|
||||
'is_valid': False,
|
||||
'normalized_query': '',
|
||||
'category': 'UNKNOWN',
|
||||
'multi_statement': False,
|
||||
'statement_count': 0
|
||||
}
|
||||
|
||||
try:
|
||||
# 标准化和格式化SQL
|
||||
formatted_sql = SQLParser._format_sql(sql_query)
|
||||
# 解析SQL语句 - 可能有多个语句
|
||||
parsed = sqlparse.parse(formatted_sql)
|
||||
|
||||
# 检查是否有多个语句
|
||||
is_multi_statement = len(parsed) > 1
|
||||
statement_count = len(parsed)
|
||||
|
||||
if not parsed:
|
||||
return {
|
||||
'operation_type': '',
|
||||
'tables': [],
|
||||
'has_where': False,
|
||||
'has_limit': False,
|
||||
'is_valid': False,
|
||||
'normalized_query': formatted_sql,
|
||||
'category': 'UNKNOWN',
|
||||
'multi_statement': False,
|
||||
'statement_count': 0
|
||||
}
|
||||
|
||||
# 默认分析第一个语句,但记录多语句信息
|
||||
stmt = parsed[0]
|
||||
|
||||
# 获取操作类型
|
||||
operation_type = SQLParser._get_operation_type(stmt)
|
||||
|
||||
# 确定操作类别
|
||||
category = SQLParser._get_operation_category(operation_type)
|
||||
|
||||
# 提取表名 - 汇总所有语句中的表名
|
||||
tables = set()
|
||||
has_where = False
|
||||
has_limit = False
|
||||
|
||||
for statement in parsed:
|
||||
# 将各语句涉及的表合并
|
||||
tables.update(SQLParser._extract_tables(statement))
|
||||
|
||||
# 检查任一语句是否有WHERE子句
|
||||
if SQLParser._has_where_clause(statement):
|
||||
has_where = True
|
||||
|
||||
# 检查任一语句是否有LIMIT子句
|
||||
if SQLParser._has_limit_clause(statement):
|
||||
has_limit = True
|
||||
|
||||
# 对于多语句,获取最高风险的操作类型
|
||||
if is_multi_statement and len(parsed) > 1:
|
||||
operations = []
|
||||
categories = []
|
||||
for statement in parsed:
|
||||
op = SQLParser._get_operation_type(statement)
|
||||
operations.append(op)
|
||||
categories.append(SQLParser._get_operation_category(op))
|
||||
|
||||
# 风险优先级: DDL > DML > METADATA
|
||||
if 'DDL' in categories:
|
||||
category = 'DDL'
|
||||
# 在DDL操作中找出优先级最高的
|
||||
# DROP/TRUNCATE > ALTER > CREATE
|
||||
if 'DROP' in operations or 'TRUNCATE' in operations:
|
||||
operation_type = 'DROP' if 'DROP' in operations else 'TRUNCATE'
|
||||
elif 'ALTER' in operations:
|
||||
operation_type = 'ALTER'
|
||||
elif 'CREATE' in operations:
|
||||
operation_type = 'CREATE'
|
||||
elif 'DML' in categories:
|
||||
category = 'DML'
|
||||
# 在DML操作中找出优先级最高的
|
||||
# DELETE > UPDATE > INSERT > SELECT
|
||||
if 'DELETE' in operations:
|
||||
operation_type = 'DELETE'
|
||||
elif 'UPDATE' in operations:
|
||||
operation_type = 'UPDATE'
|
||||
elif 'INSERT' in operations:
|
||||
operation_type = 'INSERT'
|
||||
elif 'SELECT' in operations:
|
||||
operation_type = 'SELECT'
|
||||
|
||||
return {
|
||||
'operation_type': operation_type,
|
||||
'tables': list(tables),
|
||||
'has_where': has_where,
|
||||
'has_limit': has_limit,
|
||||
'is_valid': True,
|
||||
'normalized_query': formatted_sql,
|
||||
'category': category,
|
||||
'multi_statement': is_multi_statement,
|
||||
'statement_count': statement_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SQL解析错误: {str(e)}")
|
||||
# 回退到简单的字符串解析
|
||||
result = SQLParser._fallback_parse(sql_query)
|
||||
# 添加多语句检测,简单检测分号
|
||||
result['multi_statement'] = ';' in sql_query.strip()
|
||||
result['statement_count'] = sql_query.count(';') + 1 if sql_query.strip() else 0
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_sql(sql_query: str) -> str:
|
||||
"""标准化SQL查询格式"""
|
||||
# 去除多余空白和注释
|
||||
return sqlparse.format(
|
||||
sql_query,
|
||||
strip_comments=True,
|
||||
reindent=True,
|
||||
keyword_case='upper'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_operation_type(stmt: sqlparse.sql.Statement) -> str:
|
||||
"""获取SQL操作类型"""
|
||||
# 获取第一个token
|
||||
if stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.DML:
|
||||
return stmt.tokens[0].value.upper()
|
||||
elif stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.DDL:
|
||||
return stmt.tokens[0].value.upper()
|
||||
elif stmt.tokens and stmt.tokens[0].ttype is sqlparse.tokens.Keyword:
|
||||
return stmt.tokens[0].value.upper()
|
||||
|
||||
# 如果无法确定,返回空字符串
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_operation_category(operation_type: str) -> str:
|
||||
"""确定操作类别(DDL、DML或元数据)"""
|
||||
if operation_type in SQLConfig.DDL_OPERATIONS:
|
||||
return 'DDL'
|
||||
elif operation_type in SQLConfig.DML_OPERATIONS:
|
||||
return 'DML'
|
||||
elif operation_type in SQLConfig.METADATA_OPERATIONS:
|
||||
return 'METADATA'
|
||||
else:
|
||||
return 'UNKNOWN'
|
||||
|
||||
@staticmethod
|
||||
def _extract_tables(stmt: sqlparse.sql.Statement) -> List[str]:
|
||||
"""从SQL语句中提取所有表名"""
|
||||
tables = []
|
||||
|
||||
# 根据操作类型处理表名提取
|
||||
operation_type = SQLParser._get_operation_type(stmt)
|
||||
|
||||
# 递归函数用于深入处理复杂的SQL结构
|
||||
def extract_from_token_list(token_list):
|
||||
local_tables = []
|
||||
in_from_clause = False
|
||||
in_join_clause = False
|
||||
|
||||
for token in token_list.tokens:
|
||||
# 检测FROM子句
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
|
||||
in_from_clause = True
|
||||
continue
|
||||
|
||||
# 检测JOIN子句
|
||||
if token.ttype is sqlparse.tokens.Keyword and 'JOIN' in token.value.upper():
|
||||
in_join_clause = True
|
||||
continue
|
||||
|
||||
# 在FROM或JOIN子句后提取表名
|
||||
if in_from_clause or in_join_clause:
|
||||
if isinstance(token, sqlparse.sql.Identifier):
|
||||
# 直接引用的表名
|
||||
if token.get_real_name():
|
||||
local_tables.append(token.get_real_name())
|
||||
elif isinstance(token, sqlparse.sql.IdentifierList):
|
||||
# 多个表,如FROM table1, table2
|
||||
for identifier in token.get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
local_tables.append(identifier.get_real_name())
|
||||
elif isinstance(token, sqlparse.sql.Function):
|
||||
# 处理子查询中的函数,可能包含表
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
elif isinstance(token, sqlparse.sql.Parenthesis):
|
||||
# 可能是子查询
|
||||
if token.tokens and isinstance(token.tokens[1], sqlparse.sql.Statement):
|
||||
# 是子查询,递归解析
|
||||
local_tables.extend(SQLParser._extract_tables(token.tokens[1]))
|
||||
else:
|
||||
# 其他括号结构,递归处理
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
|
||||
# 重置标志以避免收集其他部分的标识符
|
||||
if token.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.Punctuation):
|
||||
in_from_clause = False
|
||||
in_join_clause = False
|
||||
|
||||
# 递归处理其他TokenList
|
||||
if isinstance(token, sqlparse.sql.TokenList) and not isinstance(token, sqlparse.sql.Identifier):
|
||||
local_tables.extend(extract_from_token_list(token))
|
||||
|
||||
return local_tables
|
||||
|
||||
# 特殊处理DML语句
|
||||
if operation_type == 'UPDATE':
|
||||
# UPDATE语句通常在第一个标识符中包含表名
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.DML and token.value.upper() == 'UPDATE':
|
||||
if i+1 < len(stmt.tokens):
|
||||
if isinstance(stmt.tokens[i+1], sqlparse.sql.Identifier):
|
||||
tables.append(stmt.tokens[i+1].get_real_name())
|
||||
elif isinstance(stmt.tokens[i+1], sqlparse.sql.IdentifierList):
|
||||
# 多表更新
|
||||
for identifier in stmt.tokens[i+1].get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
tables.append(identifier.get_real_name())
|
||||
break
|
||||
elif operation_type == 'INSERT':
|
||||
# INSERT语句
|
||||
into_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'INTO':
|
||||
into_found = True
|
||||
elif into_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
elif into_found and isinstance(token, sqlparse.sql.Function):
|
||||
# 处理INSERT INTO table(...)
|
||||
if token.get_name():
|
||||
tables.append(token.get_name())
|
||||
break
|
||||
elif operation_type == 'DELETE':
|
||||
# DELETE FROM table
|
||||
from_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
|
||||
from_found = True
|
||||
elif from_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
elif from_found and isinstance(token, sqlparse.sql.IdentifierList):
|
||||
for identifier in token.get_identifiers():
|
||||
if identifier.get_real_name():
|
||||
tables.append(identifier.get_real_name())
|
||||
break
|
||||
elif operation_type in {'CREATE', 'ALTER', 'DROP', 'TRUNCATE'}:
|
||||
# DDL语句
|
||||
table_found = False
|
||||
for i, token in enumerate(stmt.tokens):
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'TABLE':
|
||||
table_found = True
|
||||
elif table_found and isinstance(token, sqlparse.sql.Identifier):
|
||||
tables.append(token.get_real_name())
|
||||
break
|
||||
else:
|
||||
# 对于其他语句,通过递归处理提取表名
|
||||
tables.extend(extract_from_token_list(stmt))
|
||||
|
||||
# 移除可能的重复项
|
||||
return list(set([table for table in tables if table]))
|
||||
|
||||
@staticmethod
|
||||
def _has_where_clause(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""检查SQL语句是否包含WHERE子句"""
|
||||
for token in stmt.tokens:
|
||||
if isinstance(token, sqlparse.sql.Where):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _has_limit_clause(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""检查SQL语句是否包含LIMIT子句"""
|
||||
# LIMIT通常作为一个关键字出现
|
||||
for token in stmt.tokens:
|
||||
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'LIMIT':
|
||||
return True
|
||||
# 处理更复杂的语句结构
|
||||
elif isinstance(token, sqlparse.sql.TokenList):
|
||||
for subtoken in token.tokens:
|
||||
if subtoken.ttype is sqlparse.tokens.Keyword and subtoken.value.upper() == 'LIMIT':
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _fallback_parse(sql_query: str) -> Dict:
|
||||
"""当高级解析失败时,回退到基本字符串解析"""
|
||||
sql_upper = sql_query.strip().upper()
|
||||
parts = sql_upper.split()
|
||||
|
||||
operation_type = parts[0] if parts else ""
|
||||
|
||||
# 确定操作类别
|
||||
category = 'UNKNOWN'
|
||||
if operation_type in SQLConfig.DDL_OPERATIONS:
|
||||
category = 'DDL'
|
||||
elif operation_type in SQLConfig.DML_OPERATIONS:
|
||||
category = 'DML'
|
||||
elif operation_type in SQLConfig.METADATA_OPERATIONS:
|
||||
category = 'METADATA'
|
||||
|
||||
# 基本的表名提取
|
||||
tables = []
|
||||
for i, word in enumerate(parts):
|
||||
if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}:
|
||||
if i + 1 < len(parts):
|
||||
table = parts[i + 1].strip('`;')
|
||||
if table not in {'SELECT', 'WHERE', 'SET'}:
|
||||
tables.append(table)
|
||||
|
||||
# 简单检查WHERE子句
|
||||
has_where = 'WHERE' in sql_upper
|
||||
|
||||
# 简单检查LIMIT子句
|
||||
has_limit = 'LIMIT' in sql_upper
|
||||
|
||||
return {
|
||||
'operation_type': operation_type,
|
||||
'tables': list(set(tables)),
|
||||
'has_where': has_where,
|
||||
'has_limit': has_limit,
|
||||
'is_valid': bool(operation_type),
|
||||
'normalized_query': sql_query,
|
||||
'category': category
|
||||
}
|
||||
163 src/server.py

@@ -1,12 +1,19 @@
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
import atexit
|
||||
import signal
|
||||
import importlib
|
||||
import pkgutil
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载
|
||||
load_dotenv()
|
||||
|
||||
# 导入自定义模块 - 确保在load_dotenv之后导入
|
||||
from src.config import ServerConfig, SecurityConfig, DatabaseConfig, ConnectionPoolConfig
|
||||
from src.tools.mysql_tool import register_mysql_tool
|
||||
from src.tools.mysql_metadata_tool import register_metadata_tools
|
||||
from src.tools.mysql_info_tool import register_info_tools
|
||||
@ -21,23 +28,23 @@ logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 记录环境变量加载情况
|
||||
logger.debug("已加载环境变量")
|
||||
logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}")
|
||||
logger.debug(f"当前环境类型: {os.getenv('ENV_TYPE', 'development')}")
|
||||
logger.debug(f"是否允许敏感信息查询: {os.getenv('ALLOW_SENSITIVE_INFO', 'false')}")
|
||||
logger.debug(f"当前允许的风险等级: {SecurityConfig.ALLOWED_RISK_LEVELS_STR}")
|
||||
logger.debug(f"当前环境类型: {SecurityConfig.ENV_TYPE.value}")
|
||||
logger.debug(f"是否允许敏感信息查询: {SecurityConfig.ALLOW_SENSITIVE_INFO}")
|
||||
|
||||
# 尝试导入MySQL连接器
|
||||
try:
|
||||
import mysql.connector
|
||||
logger.debug("MySQL连接器导入成功")
|
||||
import aiomysql
|
||||
logger.debug("aiomysql连接器导入成功")
|
||||
mysql_available = True
|
||||
except ImportError as e:
|
||||
logger.critical(f"无法导入MySQL连接器: {str(e)}")
|
||||
logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python")
|
||||
logger.critical(f"无法导入aiomysql连接器: {str(e)}")
|
||||
logger.critical("请确保已安装aiomysql包: pip install aiomysql")
|
||||
mysql_available = False
|
||||
|
||||
# 从环境变量获取服务器配置
|
||||
host = os.getenv('HOST', '127.0.0.1')
|
||||
port = int(os.getenv('PORT', '3000'))
|
||||
# 从配置获取服务器配置
|
||||
host = ServerConfig.HOST
|
||||
port = ServerConfig.PORT
|
||||
logger.debug(f"服务器配置: host={host}, port={port}")
|
||||
|
||||
# 创建MCP服务器实例
|
||||
@ -45,20 +52,120 @@ logger.debug("正在创建MCP服务器实例...")
|
||||
mcp = FastMCP("MySQL Query Server", "cccccccccc", host=host, port=port, debug=True, endpoint='/sse')
|
||||
logger.debug("MCP服务器实例创建完成")
|
||||
|
||||
def auto_register_tools(mcp):
|
||||
"""
|
||||
自动扫描src.tools目录下所有register_开头的函数并注册到mcp
|
||||
"""
|
||||
import src.tools
|
||||
package = src.tools
|
||||
for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package.__name__ + "."):
|
||||
if ispkg:
|
||||
continue
|
||||
module = importlib.import_module(name)
|
||||
for func_name, func in inspect.getmembers(module, inspect.isfunction):
|
||||
if func_name.startswith("register_") and func_name.endswith("tool") or func_name.endswith("tools"):
|
||||
try:
|
||||
func(mcp)
|
||||
logger.info(f"自动注册工具: {name}.{func_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"自动注册工具失败: {name}.{func_name} - {e}")
|
||||
|
||||
# 注册MySQL基础查询工具
|
||||
register_mysql_tool(mcp)
|
||||
auto_register_tools(mcp)
|
||||
logger.debug("已自动注册所有MySQL工具")
|
||||
|
||||
# 注册MySQL元数据查询工具
|
||||
register_metadata_tools(mcp)
|
||||
logger.debug("已注册元数据查询工具")
|
||||
# 启动连接池定时回收任务
|
||||
def _start_pool_cleanup_task():
|
||||
"""启动后台线程定期回收连接池资源"""
|
||||
import time
|
||||
from src.db.mysql_operations import _cleanup_unused_pools
|
||||
def _loop():
|
||||
while True:
|
||||
try:
|
||||
_cleanup_unused_pools()
|
||||
except Exception as e:
|
||||
logger.warning(f"定时回收连接池异常: {e}")
|
||||
time.sleep(300) # 每5分钟回收一次
|
||||
t = threading.Thread(target=_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
# 注册MySQL数据库信息查询工具
|
||||
register_info_tools(mcp)
|
||||
logger.debug("已注册数据库信息查询工具")
|
||||
# 用于保存事件循环和初始化状态
|
||||
_server_data = {
|
||||
'loop': None,
|
||||
'db_initialized': False
|
||||
}
|
||||
|
||||
# 注册MySQL表结构高级查询工具
|
||||
register_schema_tools(mcp)
|
||||
logger.debug("已注册表结构高级查询工具")
|
||||
def cleanup_resources():
|
||||
"""清理资源,关闭连接池"""
|
||||
if _server_data['loop'] and _server_data['db_initialized']:
|
||||
try:
|
||||
# 导入连接池关闭函数
|
||||
from src.db.mysql_operations import close_all_pools
|
||||
|
||||
# 创建关闭任务并运行
|
||||
logger.info("正在关闭所有数据库连接池...")
|
||||
close_task = close_all_pools()
|
||||
|
||||
# 在当前事件循环中运行
|
||||
if _server_data['loop'].is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(close_task, _server_data['loop'])
|
||||
future.result(timeout=5) # 等待最多5秒
|
||||
else:
|
||||
# 如果循环已经停止,创建新的循环运行清理任务
|
||||
temp_loop = asyncio.new_event_loop()
|
||||
temp_loop.run_until_complete(close_task)
|
||||
temp_loop.close()
|
||||
|
||||
logger.info("数据库连接池已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭数据库连接池时出错: {str(e)}")
|
||||
|
||||
# 注册退出处理函数
|
||||
atexit.register(cleanup_resources)
|
||||
|
||||
# 注册信号处理
|
||||
def signal_handler(sig, frame):
|
||||
"""处理终止信号"""
|
||||
logger.info(f"收到信号 {sig},正在清理资源...")
|
||||
cleanup_resources()
|
||||
# 正常退出
|
||||
exit(0)
|
||||
|
||||
# 注册常见的终止信号
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # kill
|
||||
|
||||
async def init_database():
|
||||
"""初始化数据库连接池"""
|
||||
try:
|
||||
from src.db.mysql_operations import init_db_pool, get_db_config
|
||||
|
||||
# 获取数据库配置
|
||||
db_config = get_db_config()
|
||||
|
||||
# 获取连接池配置
|
||||
pool_config = ConnectionPoolConfig.get_config()
|
||||
min_size = pool_config['minsize']
|
||||
max_size = pool_config['maxsize']
|
||||
|
||||
# 记录连接池配置
|
||||
logger.info(f"连接池配置: 最小连接数={min_size}, 最大连接数={max_size}, 回收时间={pool_config['pool_recycle']}秒")
|
||||
logger.info(f"连接池功能状态: {'启用' if pool_config['enabled'] else '禁用'}")
|
||||
|
||||
if not db_config.get('db'):
|
||||
logger.warning("未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
print("警告: 未设置数据库名称,请检查环境变量MYSQL_DATABASE")
|
||||
# 初始化连接池但不要求指定数据库
|
||||
await init_db_pool(require_database=False)
|
||||
else:
|
||||
# 正常初始化连接池
|
||||
await init_db_pool()
|
||||
|
||||
logger.info("数据库连接池初始化完成")
|
||||
_server_data['db_initialized'] = True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接池初始化失败: {str(e)}")
|
||||
print(f"警告: 数据库连接池初始化失败: {str(e)}")
|
||||
|
||||
def start_server():
    """启动SSE服务器的同步包装器"""
@@ -68,12 +175,11 @@ def start_server():
    print(f"服务器监听在 {host}:{port}/sse")

    try:
        # 检查MySQL配置是否有效
        from src.db.mysql_operations import get_db_config
        db_config = get_db_config()
        if mysql_available and not db_config['database']:
            logger.warning("未设置数据库名称,请检查环境变量MYSQL_DATABASE")
            print("警告: 未设置数据库名称,请检查环境变量MYSQL_DATABASE")
        # 检查MySQL配置是否有效并初始化连接池
        if mysql_available:
            # 使用事件循环执行异步初始化函数
            _server_data['loop'] = asyncio.get_event_loop()
            _server_data['loop'].run_until_complete(init_database())

        # 使用run_app函数启动服务器
        logger.debug("调用mcp.run('sse')启动服务器...")
@@ -81,6 +187,9 @@ def start_server():
    except Exception as e:
        logger.exception(f"服务器运行时发生错误: {str(e)}")
        print(f"服务器运行时发生错误: {str(e)}")
    finally:
        # 确保资源被清理
        cleanup_resources()

if __name__ == "__main__":
    # 确保初始化后工具才被注册

@@ -5,12 +5,12 @@ MySQL元数据工具基类

import json
import logging
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic, Callable
from typing import Any, Dict, List, Optional, Union, Callable
import functools

from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators, ValidationError

T = TypeVar('T')
logger = logging.getLogger("mysql_server")

class MySQLToolError(Exception):
@@ -50,8 +50,10 @@ class MetadataToolBase:
        Raises:
            ParameterValidationError: 当参数验证失败时
        """
        if param_value is not None and not validator(param_value):
            raise ParameterValidationError(f"{param_name} - {error_message}")
        try:
            SQLValidators.validate_parameter(param_name, param_value, validator, "参数验证")
        except ValidationError as e:
            raise ParameterValidationError(str(e))

    @staticmethod
    def format_results(results: List[Dict[str, Any]], operation_type: str = "元数据查询") -> str:
@@ -91,13 +93,22 @@ class MetadataToolBase:
                return await func(*args, **kwargs)
            except ParameterValidationError as e:
                logger.error(f"参数验证错误: {str(e)}")
                return json.dumps({"error": f"参数错误: {str(e)}"})
                return json.dumps({
                    "error": f"参数错误: {str(e)}",
                    "error_type": "ParameterValidationError"
                })
            except QueryExecutionError as e:
                logger.error(f"查询执行错误: {str(e)}")
                return json.dumps({"error": f"查询执行失败: {str(e)}"})
                return json.dumps({
                    "error": f"查询执行失败: {str(e)}",
                    "error_type": "QueryExecutionError"
                })
            except Exception as e:
                logger.error(f"未预期的错误: {str(e)}")
                return json.dumps({"error": f"操作失败: {str(e)}"})
                return json.dumps({
                    "error": f"操作失败: {str(e)}",
                    "error_type": "UnexpectedError"
                })
        return wrapper

    @staticmethod
@@ -115,9 +126,9 @@ class MetadataToolBase:
            查询结果的JSON字符串
        """
        try:
            with get_db_connection() as connection:
            async with get_db_connection() as connection:
                results = await execute_query(connection, query, params)
                return MetadataToolBase.format_results(results, operation_type)
        except Exception as e:
            logger.error(f"元数据查询执行失败: {str(e)}")
            raise QueryExecutionError(str(e))
            raise QueryExecutionError(str(e)) from e  # 保留原始异常链
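With the `error_type` field added above, callers of any tool wrapped by handle_query_error can branch on the kind of failure instead of parsing the message text. A small illustrative sketch; `broken_tool` is made up, and only the returned JSON shape comes from the code above.

import asyncio
import json

# Assumes MetadataToolBase and ParameterValidationError are imported from the
# metadata_base_tool module shown above.

@MetadataToolBase.handle_query_error
async def broken_tool() -> str:
    # Made-up tool body that always fails parameter validation.
    raise ParameterValidationError("database - 数据库名称只能包含字母、数字和下划线")

payload = json.loads(asyncio.run(broken_tool()))
assert payload["error_type"] == "ParameterValidationError"
print(payload["error"])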
@@ -13,6 +13,7 @@ from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.security.sql_analyzer import EnvironmentType
from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators

logger = logging.getLogger("mysql_server")

@@ -46,6 +47,10 @@ SENSITIVE_VARIABLE_PREFIXES = [
    "authentication", "secure", "credential", "token"
]

# 变量名和值字段映射
VARIABLE_NAME_FIELDS = ['Variable_name', 'variable_name', 'name', 'Name', 'key', 'Key', 'Setting']
VALUE_FIELDS = ['Value', 'value', 'variable_value', 'val', 'setting', 'Setting_Value']

def check_environment_permission(env_type: EnvironmentType, query_type: str) -> bool:
    """
    检查当前环境是否允许执行特定类型的查询
@@ -92,19 +97,27 @@ def filter_sensitive_info(results: List[Dict[str, Any]], filter_patterns: List[s
        # 复制一份,避免修改原始数据
        filtered_item = item.copy()

        # 检查常见的变量名字段
        for field in ['Variable_name', 'variable_name', 'name']:
        # 确定哪个字段包含变量名
        name_field = None
        for field in VARIABLE_NAME_FIELDS:
            if field in filtered_item:
                var_name = filtered_item[field].lower()
                # 检查是否匹配敏感模式
                is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)
                name_field = field
                break

                if is_sensitive:
                    # 敏感信息,隐藏具体的值
                    for value_field in ['Value', 'value', 'variable_value']:
                        if value_field in filtered_item:
                            filtered_item[value_field] = '*** HIDDEN ***'

        # 如果找到变量名字段,检查是否敏感
        if name_field:
            var_name = str(filtered_item[name_field]).lower()
            # 检查是否匹配敏感模式
            is_sensitive = any(re.search(pattern, var_name, re.IGNORECASE) for pattern in filter_patterns)

            if is_sensitive:
                # 找出所有可能的值字段
                for value_field in VALUE_FIELDS:
                    if value_field in filtered_item:
                        # 敏感信息,隐藏具体的值
                        filtered_item[value_field] = '*** HIDDEN ***'
                        logger.debug(f"已隐藏敏感变量 '{var_name}' 的值")

        filtered_results.append(filtered_item)

    return filtered_results
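A worked example of what the reworked filter does; the rows below are made up, and `secure_file_priv` is used only because "secure" appears in the sensitive-prefix list shown above.

# Illustrative input/output for filter_sensitive_info (rows are made up).
rows = [
    {"Variable_name": "version", "Value": "8.0.36"},
    {"Variable_name": "secure_file_priv", "Value": "/var/lib/mysql-files/"},
]
filtered = filter_sensitive_info(rows)
# The second row matches the "secure" pattern, so its Value becomes '*** HIDDEN ***';
# the first row passes through unchanged.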
@@ -136,21 +149,21 @@ def register_info_tools(mcp: FastMCP):
        if pattern:
            MetadataToolBase.validate_parameter(
                "pattern", pattern,
                lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
                SQLValidators.validate_like_pattern,
                "模式只能包含字母、数字、下划线和通配符(%_)"
            )

        MetadataToolBase.validate_parameter(
            "limit", limit,
            lambda x: isinstance(x, int) and x >= 0,
            lambda x: SQLValidators.validate_integer(x, min_value=0),
            "返回结果的最大数量必须是非负整数"
        )

        # 构建基础查询
        query = "SHOW DATABASES"

        # 执行查询
        with get_db_connection() as connection:
        # 执行查询 - 使用异步上下文管理器,不要求预先指定数据库
        async with get_db_connection(require_database=False) as connection:
            # 先获取所有数据库
            results = await execute_query(connection, query)

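The `async with get_db_connection(...)` calls above rely on an asynchronous context manager defined in src/db/mysql_operations.py, which this diff does not show. A hedged sketch of the usual aiomysql shape of such a helper; the pool handling is an assumption, not the project's confirmed implementation.

from contextlib import asynccontextmanager

@asynccontextmanager
async def get_db_connection(require_database: bool = True):
    """Illustrative only: acquire a connection from the shared pool, then release it."""
    pool = await init_db_pool(require_database=require_database)  # assumed to reuse an existing pool
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)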
@@ -227,7 +240,7 @@ def register_info_tools(mcp: FastMCP):
        if pattern:
            MetadataToolBase.validate_parameter(
                "pattern", pattern,
                lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
                SQLValidators.validate_like_pattern,
                "变量模式只能包含字母、数字、下划线和通配符(%_)"
            )

@@ -239,7 +252,7 @@ def register_info_tools(mcp: FastMCP):

        logger.debug(f"执行查询: {query}")

        with get_db_connection() as connection:
        async with get_db_connection() as connection:
            results = await execute_query(connection, query)

            # 生产环境中过滤敏感信息
@@ -273,7 +286,7 @@ def register_info_tools(mcp: FastMCP):
        if pattern:
            MetadataToolBase.validate_parameter(
                "pattern", pattern,
                lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
                SQLValidators.validate_like_pattern,
                "状态模式只能包含字母、数字、下划线和通配符(%_)"
            )

@@ -285,48 +298,11 @@ def register_info_tools(mcp: FastMCP):

        logger.debug(f"执行查询: {query}")

        with get_db_connection() as connection:
        async with get_db_connection() as connection:
            results = await execute_query(connection, query)

            # 生产环境中过滤敏感信息
            if env_type == EnvironmentType.PRODUCTION:
                results = filter_sensitive_info(results)

            return MetadataToolBase.format_results(results, operation_type="服务器状态查询")

# 工具函数: 用于参数验证
def validate_pattern(pattern: str) -> bool:
    """
    验证模式字符串是否安全 (防止SQL注入)

    Args:
        pattern: 要验证的模式字符串

    Returns:
        如果模式安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当模式包含不安全字符时
    """
    # 仅允许字母、数字、下划线和通配符(% 和 _)
    if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
        raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
    return True

def validate_engine_name(name: str) -> bool:
    """
    验证存储引擎名称是否合法安全

    Args:
        name: 要验证的引擎名称

    Returns:
        如果引擎名称安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当引擎名称包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的引擎名称: {name}, 引擎名称只能包含字母、数字和下划线")
    return True
            return MetadataToolBase.format_results(results, operation_type="服务器状态查询")
@@ -11,64 +11,10 @@ from mcp.server.fastmcp import FastMCP

from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators

logger = logging.getLogger("mysql_server")

# 工具函数: 用于参数验证
def validate_pattern(pattern: str) -> bool:
    """
    验证模式字符串是否安全 (防止SQL注入)

    Args:
        pattern: 要验证的模式字符串

    Returns:
        如果模式安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当模式包含不安全字符时
    """
    # 仅允许字母、数字、下划线和通配符(% 和 _)
    if not re.match(r'^[a-zA-Z0-9_%]+$', pattern):
        raise ValueError("模式只能包含字母、数字、下划线和通配符(%_)")
    return True

def validate_table_name(name: str) -> bool:
    """
    验证表名是否合法安全

    Args:
        name: 要验证的表名

    Returns:
        如果表名安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当表名包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
    return True

def validate_database_name(name: str) -> bool:
    """
    验证数据库名是否合法安全

    Args:
        name: 要验证的数据库名

    Returns:
        如果数据库名安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当数据库名包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
    return True

def register_metadata_tools(mcp: FastMCP):
    """
    注册MySQL元数据查询工具到MCP服务器
@@ -98,20 +44,20 @@ def register_metadata_tools(mcp: FastMCP):
        if database:
            MetadataToolBase.validate_parameter(
                "database", database,
                lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

        if pattern:
            MetadataToolBase.validate_parameter(
                "pattern", pattern,
                lambda x: re.match(r'^[a-zA-Z0-9_%]+$', x),
                SQLValidators.validate_like_pattern,
                "模式只能包含字母、数字、下划线和通配符(%_)"
            )

        MetadataToolBase.validate_parameter(
            "limit", limit,
            lambda x: isinstance(x, int) and x >= 0,
            lambda x: SQLValidators.validate_integer(x, min_value=0),
            "返回结果的最大数量必须是非负整数"
        )

@@ -124,8 +70,8 @@ def register_metadata_tools(mcp: FastMCP):

        logger.debug(f"执行查询: {base_query}")

        # 执行查询
        with get_db_connection() as connection:
        # 执行查询 - 使用异步上下文管理器
        async with get_db_connection() as connection:
            results = await execute_query(connection, base_query)

            # 如果需要排除视图,且使用的是SHOW FULL TABLES
@@ -133,17 +79,46 @@ def register_metadata_tools(mcp: FastMCP):
                filtered_results = []

                # 查找表名和表类型字段
                fields = list(results[0].keys()) if results else []
                table_field = fields[0] if fields else None
                table_type_field = fields[1] if len(fields) > 1 else None
                if results:
                    # 确定表名和表类型字段名
                    field_names = list(results[0].keys())
                    table_type_field = None
                    table_field = None

                    # 查找表类型字段 - 这通常是'Table_type',但也检查其他可能的名称
                    possible_type_fields = ['Table_type', 'table_type', 'type']
                    for field in possible_type_fields:
                        if field in field_names:
                            table_type_field = field
                            break

                    # 查找表名字段 - 这可能是结果中的第一个字段
                    for field in field_names:
                        if field != table_type_field:  # 表名不会是类型字段
                            if field.lower() in ['table', 'name', 'table_name', 'tables_in_']:
                                table_field = field
                                break

                    # 如果没找到明确的表名字段,使用第一个非类型字段
                    if not table_field and len(field_names) > 0:
                        for field in field_names:
                            if field != table_type_field:
                                table_field = field
                                break

                if table_field and table_type_field:
                    # 基表类型通常是"BASE TABLE"
                    for item in results:
                        if item[table_type_field] == 'BASE TABLE':
                            filtered_results.append(item)
                    # 只有当我们能确定表名和类型字段时才进行过滤
                    if table_field and table_type_field:
                        logger.debug(f"表名字段: {table_field}, 表类型字段: {table_type_field}")
                        # 只保留基表 (BASE TABLE),排除视图和其他对象
                        for item in results:
                            if item[table_type_field] == 'BASE TABLE':
                                filtered_results.append(item)
                    else:
                        # 如果无法确定字段,保留所有结果并记录警告
                        logger.warning("无法确定表类型字段,无法排除视图")
                        filtered_results = results
                else:
                    filtered_results = results
                    filtered_results = []
            else:
                filtered_results = results

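For reference, SHOW FULL TABLES names its first column Tables_in_<database> and its second column Table_type, which is exactly what the field-detection logic above has to cope with. An illustrative result set (the database name 'shop' is made up):

rows = [
    {"Tables_in_shop": "orders", "Table_type": "BASE TABLE"},
    {"Tables_in_shop": "order_summary", "Table_type": "VIEW"},
]
# With the logic above, table_type_field == "Table_type"; "tables_in_shop" is not in the
# explicit name list, so table_field falls back to the first non-type field, "Tables_in_shop".
# Only the 'orders' row survives the BASE TABLE filter.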
@@ -189,14 +164,14 @@ def register_metadata_tools(mcp: FastMCP):
        # 参数验证
        MetadataToolBase.validate_parameter(
            "table", table,
            lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
            SQLValidators.validate_table_name,
            "表名只能包含字母、数字和下划线"
        )

        if database:
            MetadataToolBase.validate_parameter(
                "database", database,
                lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

@@ -221,22 +196,21 @@ def register_metadata_tools(mcp: FastMCP):
        # 参数验证
        MetadataToolBase.validate_parameter(
            "table", table,
            lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
            SQLValidators.validate_table_name,
            "表名只能包含字母、数字和下划线"
        )

        if database:
            MetadataToolBase.validate_parameter(
                "database", database,
                lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

        # DESCRIBE 语句与 SHOW COLUMNS 功能类似,但结果格式可能略有不同
        query = f"DESCRIBE `{table}`" if not database else f"DESCRIBE `{database}`.`{table}`"
        logger.debug(f"执行查询: {query}")

        return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述")
        return await MetadataToolBase.execute_metadata_query(query, operation_type="表结构描述查询")

    @mcp.tool()
    @MetadataToolBase.handle_query_error
@@ -254,14 +228,14 @@ def register_metadata_tools(mcp: FastMCP):
        # 参数验证
        MetadataToolBase.validate_parameter(
            "table", table,
            lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
            SQLValidators.validate_table_name,
            "表名只能包含字母、数字和下划线"
        )

        if database:
            MetadataToolBase.validate_parameter(
                "database", database,
                lambda x: re.match(r'^[a-zA-Z0-9_]+$', x),
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

@@ -6,74 +6,21 @@ MySQL表结构高级查询工具
import json
import logging
import re
import os
from typing import Any, Dict, List, Optional, Union
from mcp.server.fastmcp import FastMCP

from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators

logger = logging.getLogger("mysql_server")

# 参数验证函数
def validate_table_name(name: str) -> bool:
    """
    验证表名是否合法安全

    Args:
        name: 要验证的表名

    Returns:
        如果表名安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当表名包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的表名: {name}, 表名只能包含字母、数字和下划线")
    return True

def validate_database_name(name: str) -> bool:
    """
    验证数据库名是否合法安全

    Args:
        name: 要验证的数据库名

    Returns:
        如果数据库名安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当数据库名包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的数据库名: {name}, 数据库名只能包含字母、数字和下划线")
    return True

def validate_column_name(name: str) -> bool:
    """
    验证列名是否合法安全

    Args:
        name: 要验证的列名

    Returns:
        如果列名安全返回True,否则抛出ValueError

    Raises:
        ValueError: 当列名包含不安全字符时
    """
    # 仅允许字母、数字和下划线
    if not re.match(r'^[a-zA-Z0-9_]+$', name):
        raise ValueError(f"无效的列名: {name}, 列名只能包含字母、数字和下划线")
    return True

async def execute_schema_query(
    query: str,
    params: Optional[Dict[str, Any]] = None,
    operation_type: str = "元数据查询"
    operation_type: str = "元数据查询",
    stream_results: bool = False,
    batch_size: int = 1000
) -> str:
    """
    执行表结构查询
@@ -82,12 +29,20 @@ async def execute_schema_query(
        query: SQL查询语句
        params: 查询参数 (可选)
        operation_type: 操作类型描述
        stream_results: 是否使用流式处理获取大型结果集
        batch_size: 批处理大小,分批获取结果时的每批记录数量

    Returns:
        查询结果的JSON字符串
    """
    with get_db_connection() as connection:
        results = await execute_query(connection, query, params)
    async with get_db_connection() as connection:
        results = await execute_query(
            connection,
            query,
            params,
            batch_size=batch_size,
            stream_results=stream_results
        )
        return MetadataToolBase.format_results(results, operation_type)

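Example call sites for the extended execute_schema_query signature; the queries and operation_type strings below are illustrative, not taken from the project.

async def demo() -> None:
    # Small result set: the defaults are fine.
    small = await execute_schema_query(
        "SHOW INDEX FROM `orders`",
        operation_type="表索引查询",
    )
    # Potentially large result set: opt into streaming with a smaller batch size.
    large = await execute_schema_query(
        "SELECT * FROM information_schema.COLUMNS",
        operation_type="列信息查询",
        stream_results=True,
        batch_size=500,
    )
    print(len(small), len(large))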
def register_schema_tools(mcp: FastMCP):
@@ -113,10 +68,18 @@ def register_schema_tools(mcp: FastMCP):
            表索引信息的JSON字符串
        """
        # 参数验证
        validate_table_name(table)
        MetadataToolBase.validate_parameter(
            "table", table,
            SQLValidators.validate_table_name,
            "表名只能包含字母、数字和下划线"
        )

        if database:
            validate_database_name(database)
            MetadataToolBase.validate_parameter(
                "database", database,
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

        # 构建查询
        table_ref = f"`{table}`" if not database else f"`{database}`.`{table}`"
@@ -141,10 +104,18 @@ def register_schema_tools(mcp: FastMCP):
        """
        # 参数验证
        if database:
            validate_database_name(database)
            MetadataToolBase.validate_parameter(
                "database", database,
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

        if like_pattern:
            validate_column_name(like_pattern)
            MetadataToolBase.validate_parameter(
                "like_pattern", like_pattern,
                SQLValidators.validate_like_pattern,
                "模式只能包含字母、数字、下划线和通配符(%_)"
            )

        # 构建查询
        if database:
@@ -174,16 +145,24 @@ def register_schema_tools(mcp: FastMCP):
            表外键约束信息的JSON字符串
        """
        # 参数验证
        validate_table_name(table)
        MetadataToolBase.validate_parameter(
            "table", table,
            SQLValidators.validate_table_name,
            "表名只能包含字母、数字和下划线"
        )

        if database:
            validate_database_name(database)
            MetadataToolBase.validate_parameter(
                "database", database,
                SQLValidators.validate_database_name,
                "数据库名称只能包含字母、数字和下划线"
            )

        # 确定数据库名
        db_name = database
        if not db_name:
            # 获取当前数据库
            with get_db_connection() as connection:
            # 获取当前数据库 - 使用异步上下文管理器
            async with get_db_connection() as connection:
                current_db_results = await execute_query(connection, "SELECT DATABASE() as db")
                if current_db_results and 'db' in current_db_results[0]:
                    db_name = current_db_results[0]['db']
@@ -191,7 +170,7 @@ def register_schema_tools(mcp: FastMCP):
        if not db_name:
            raise ValueError("无法确定数据库名称,请明确指定database参数")

        # 使用INFORMATION_SCHEMA查询外键
        # 使用INFORMATION_SCHEMA查询外键 - 修改为使用命名参数
        query = """
            SELECT
                CONSTRAINT_NAME,
@@ -208,16 +187,19 @@ def register_schema_tools(mcp: FastMCP):
            ON
                kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
            WHERE
                kcu.TABLE_SCHEMA = %s
                AND kcu.TABLE_NAME = %s
                kcu.TABLE_SCHEMA = %(table_schema)s
                AND kcu.TABLE_NAME = %(table_name)s
                AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
        """
        params = {'TABLE_SCHEMA': db_name, 'TABLE_NAME': table}

        logger.debug(f"执行查询: 获取表 {db_name}.{table} 的外键约束")
        # 使用命名参数,键名与SQL中的占位符对应
        params = {"table_schema": db_name, "table_name": table}

        logger.debug(f"执行外键查询: {query}")
        logger.debug(f"参数: {params}")

        # 执行查询
        return await execute_schema_query(query, params, operation_type="外键约束查询")
        return await execute_schema_query(query, params, operation_type="表外键查询")

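The switch from positional %s to named %(table_schema)s placeholders uses the pyformat parameter style that aiomysql inherits from PyMySQL: the dict keys must match the placeholder names. A standalone illustration; the query and the count_columns helper are made up, not part of this commit.

import aiomysql

async def count_columns(connection, db_name: str, table: str) -> int:
    """Illustrative named-parameter query in the same %(name)s style as above."""
    query = """
        SELECT COUNT(*) AS n
        FROM information_schema.COLUMNS
        WHERE TABLE_SCHEMA = %(table_schema)s
          AND TABLE_NAME = %(table_name)s
    """
    async with connection.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(query, {"table_schema": db_name, "table_name": table})
        row = await cursor.fetchone()
    return row["n"]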
    @mcp.tool()
    @MetadataToolBase.handle_query_error
@@ -236,47 +218,62 @@ def register_schema_tools(mcp: FastMCP):
        # 参数验证
        MetadataToolBase.validate_parameter(
            "page", page,
            lambda x: isinstance(x, int) and x > 0,
            lambda x: SQLValidators.validate_integer(x, min_value=1),
            "页码必须是正整数"
        )

        MetadataToolBase.validate_parameter(
            "page_size", page_size,
            lambda x: isinstance(x, int) and 1 <= x <= 1000,
            "每页记录数必须在1-1000之间"
            lambda x: SQLValidators.validate_integer(x, min_value=1, max_value=1000),
            "每页记录数必须是正整数且不超过1000"
        )

        # 检查查询语法
        if not query.strip().upper().startswith('SELECT'):
            raise ValueError("只支持SELECT查询的分页")

        # 计算LIMIT和OFFSET
        # 计算偏移量
        offset = (page - 1) * page_size

        # 在查询末尾添加LIMIT子句
        paginated_query = query.strip()
        if 'LIMIT' in paginated_query.upper():
            raise ValueError("查询已包含LIMIT子句,请移除后重试")
        # 分离基础查询和LIMIT/OFFSET部分
        base_query = query.strip()
        if re.search(r'\bLIMIT\b', base_query, re.IGNORECASE):
            raise ValueError("查询语句已包含LIMIT子句,不能与分页功能一起使用")

        paginated_query += f" LIMIT {page_size} OFFSET {offset}"
        # 添加LIMIT和OFFSET
        paginated_query = f"{base_query} LIMIT {page_size} OFFSET {offset}"

        logger.debug(f"执行分页查询: 页码={page}, 每页记录数={page_size}")
        logger.debug(f"执行分页查询: {paginated_query}")
        logger.debug(f"页码: {page}, 每页记录数: {page_size}, 偏移量: {offset}")

        # 获取总记录数(用于计算总页数)
        count_query = f"SELECT COUNT(*) as total FROM ({query}) as temp_count_table"

        with get_db_connection() as connection:
            # 执行分页查询
        # 执行查询 - 使用异步上下文管理器
        async with get_db_connection() as connection:
            # 首先检查并验证查询
            # 确认查询安全性 - 限制查询类型,只允许SELECT查询
            if not base_query.strip().upper().startswith('SELECT'):
                raise ValueError("只支持SELECT查询进行分页")

            # 使用普通查询获取当前页结果(不需要流式处理,因为已经有LIMIT限制)
            results = await execute_query(connection, paginated_query)

            # 获取总记录数
            count_results = await execute_query(connection, count_query)
            total_records = count_results[0]['total'] if count_results else 0

            # 计算总页数
            total_pages = (total_records + page_size - 1) // page_size

            # 构建分页元数据
            # 尝试获取总记录数 - 对于大型结果集使用流式处理
            try:
                # 由于无法参数化子查询,我们改为构建一个只返回计数的查询
                # 这仍有SQL注入风险,但我们已经验证查询只能是SELECT
                count_query = f"SELECT COUNT(*) as total FROM ({base_query}) as subquery"
                # 计数查询通常只返回一行,不需要流式处理
                count_results = await execute_query(connection, count_query)
                total = count_results[0]['total'] if count_results else 0

                # 根据总记录数计算是否是大型结果集
                is_large_resultset = total > 1000

                # 提示用户结果集大小
                if is_large_resultset:
                    logger.info(f"检测到大型结果集,共 {total} 条记录,建议使用较小的 page_size 值")

            except Exception as e:
                logger.warning(f"无法执行总数查询: {str(e)}")
                total = None
                is_large_resultset = False

            # 构造分页元数据
            pagination_info = {
                "metadata_info": {
                    "operation_type": "分页查询",
@@ -284,8 +281,11 @@ def register_schema_tools(mcp: FastMCP):
                    "pagination": {
                        "page": page,
                        "page_size": page_size,
                        "total_records": total_records,
                        "total_pages": total_pages
                        "total_records": total,
                        "total_pages": (total + page_size - 1) // page_size if total else None,
                        "has_next": (page * page_size < total) if total is not None else len(results) == page_size,
                        "has_previous": page > 1,
                        "is_large_resultset": is_large_resultset if total is not None else None
                    }
                },
                "results": results

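A worked example of the pagination arithmetic and flags built above (the numbers are made up):

page, page_size, total = 2, 10, 25
offset = (page - 1) * page_size                      # 10
total_pages = (total + page_size - 1) // page_size   # 3
has_next = page * page_size < total                  # 20 < 25 -> True
has_previous = page > 1                              # True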
@@ -3,16 +3,14 @@ import logging
from typing import Any, Dict, Optional
from mcp.server.fastmcp import FastMCP
from src.db.mysql_operations import get_db_connection, execute_query
import mysql.connector
import aiomysql

from .metadata_base_tool import MetadataToolBase

logger = logging.getLogger("mysql_server")

# 尝试导入MySQL连接器
try:
    mysql.connector
    mysql_available = True
except ImportError:
    mysql_available = False
# MySQL可用性检查变量,默认认为aiomysql已可用
mysql_available = True

def register_mysql_tool(mcp: FastMCP):
    """
@@ -24,6 +22,7 @@ def register_mysql_tool(mcp: FastMCP):
    logger.debug("注册MySQL查询工具...")

    @mcp.tool()
    @MetadataToolBase.handle_query_error
    async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
        """
        执行MySQL查询并返回结果
@@ -37,18 +36,22 @@ def register_mysql_tool(mcp: FastMCP):
        """
        logger.debug(f"执行MySQL查询: {query}, 参数: {params}")

        try:
            with get_db_connection() as connection:
                results = await execute_query(connection, query, params)
        async with get_db_connection() as connection:
            results = await execute_query(connection, query, params)

                # 检查是否是修改操作返回的影响行数
                operation = query.strip().split()[0].upper()
                if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
                    affected_rows = results[0]['affected_rows']
                    logger.info(f"{operation}操作影响了{affected_rows}行数据")

            # 检查是否是修改操作返回的影响行数
            operation = query.strip().split()[0].upper()
            if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]:
                affected_rows = results[0]['affected_rows']
                logger.info(f"{operation}操作影响了{affected_rows}行数据")

                return json.dumps(results, default=str)

        except Exception as e:
            logger.error(f"执行查询时发生异常: {str(e)}")
            return json.dumps({"error": str(e)})
            # 添加元数据信息
            metadata_info = {
                "metadata_info": {
                    "operation_type": operation,
                    "result_count": len(results)
                },
                "results": results
            }

            return json.dumps(metadata_info, default=str)
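An illustrative view of the envelope mysql_query now returns. The tool is registered on the MCP server rather than imported directly, so this direct call is only a sketch, and the table, columns, and values are made up.

import json

async def demo() -> None:
    raw = await mysql_query("SELECT id, name FROM users LIMIT 2")
    payload = json.loads(raw)
    # Expected shape:
    # {"metadata_info": {"operation_type": "SELECT", "result_count": 2},
    #  "results": [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]}
    print(payload["metadata_info"]["result_count"])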
127
src/validators.py
Normal file
@@ -0,0 +1,127 @@
import re
from typing import Any, Callable, Optional

class ValidationError(Exception):
    """验证错误异常"""
    pass

class SQLValidators:
    """SQL相关验证器集合"""

    # 正则表达式常量
    IDENTIFIER_PATTERN = r'^[a-zA-Z0-9_]+$'
    PATTERN_PATTERN = r'^[a-zA-Z0-9_%]+$'

    @staticmethod
    def validate_identifier(name: str, entity_type: str = "标识符") -> bool:
        """
        验证SQL标识符是否合法安全(表名、数据库名、列名等)

        Args:
            name: 要验证的标识符
            entity_type: 实体类型名称,用于错误信息

        Returns:
            如果标识符安全返回True

        Raises:
            ValidationError: 当标识符包含不安全字符时
        """
        if not name:
            raise ValidationError(f"{entity_type}不能为空")

        if not re.match(SQLValidators.IDENTIFIER_PATTERN, name):
            raise ValidationError(f"无效的{entity_type}: {name}, {entity_type}只能包含字母、数字和下划线")
        return True

    @staticmethod
    def validate_table_name(name: str) -> bool:
        """验证表名是否合法安全"""
        return SQLValidators.validate_identifier(name, "表名")

    @staticmethod
    def validate_database_name(name: str) -> bool:
        """验证数据库名是否合法安全"""
        return SQLValidators.validate_identifier(name, "数据库名")

    @staticmethod
    def validate_column_name(name: str) -> bool:
        """验证列名是否合法安全"""
        return SQLValidators.validate_identifier(name, "列名")

    @staticmethod
    def validate_like_pattern(pattern: str) -> bool:
        """
        验证LIKE查询模式是否安全

        Args:
            pattern: 要验证的模式字符串

        Returns:
            如果模式安全返回True

        Raises:
            ValidationError: 当模式包含不安全字符时
        """
        if not pattern:
            raise ValidationError("模式不能为空")

        if not re.match(SQLValidators.PATTERN_PATTERN, pattern):
            raise ValidationError(f"无效的模式: {pattern}, 模式只能包含字母、数字、下划线和通配符(%_)")
        return True

    @staticmethod
    def validate_integer(value: int, min_value: Optional[int] = None, max_value: Optional[int] = None) -> bool:
        """
        验证整数值是否在允许范围内

        Args:
            value: 要验证的整数值
            min_value: 最小允许值(可选)
            max_value: 最大允许值(可选)

        Returns:
            如果值合法返回True

        Raises:
            ValidationError: 当值不合法时
        """
        if not isinstance(value, int):
            raise ValidationError(f"值必须是整数,当前类型: {type(value).__name__}")

        if min_value is not None and value < min_value:
            raise ValidationError(f"值必须大于或等于 {min_value}")

        if max_value is not None and value > max_value:
            raise ValidationError(f"值必须小于或等于 {max_value}")

        return True

    @staticmethod
    def validate_parameter(param_name: str, param_value: Any, validator: Callable, error_prefix: str = "") -> bool:
        """
        通用参数验证函数

        Args:
            param_name: 参数名称
            param_value: 参数值
            validator: 验证函数
            error_prefix: 错误信息前缀

        Returns:
            如果验证通过返回True

        Raises:
            ValidationError: 当验证失败时
        """
        if param_value is None:
            return True  # 允许None值

        try:
            return validator(param_value)
        except ValidationError as e:
            prefix = f"{error_prefix}: " if error_prefix else ""
            raise ValidationError(f"{prefix}{param_name} - {str(e)}")
        except Exception as e:
            prefix = f"{error_prefix}: " if error_prefix else ""
            raise ValidationError(f"{prefix}{param_name} - 验证失败: {str(e)}")
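Quick usage sketch for the new validators module; the inputs below are made up.

from src.validators import SQLValidators, ValidationError

SQLValidators.validate_table_name("orders")                        # True
SQLValidators.validate_like_pattern("user_%")                      # True
SQLValidators.validate_integer(50, min_value=1, max_value=1000)    # True

try:
    SQLValidators.validate_table_name("orders; DROP TABLE users")  # unsafe characters
except ValidationError as exc:
    print(f"拒绝不安全的输入: {exc}")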