<feat> 添加数据库隔离功能,支持跨数据库访问限制,更新配置和文档说明,增强安全性

This commit is contained in:
tangyi
2025-06-19 11:48:39 +08:00
parent 70bdd20333
commit adf5c73dcc
7 changed files with 343 additions and 3 deletions

View File

@ -18,6 +18,7 @@ This project is a MySQL query server based on the MCP framework, supporting real
- 丰富的MySQL元数据与结构查询API - 丰富的MySQL元数据与结构查询API
- 自动事务管理与回滚 - 自动事务管理与回滚
- 多级SQL风险控制与注入防护 - 多级SQL风险控制与注入防护
- **数据库隔离安全**:防止跨数据库访问,支持三级访问控制
- 敏感信息自动隐藏与自定义 - 敏感信息自动隐藏与自定义
- 灵活的环境变量配置 - 灵活的环境变量配置
- 完善的日志与错误处理 - 完善的日志与错误处理
@ -29,6 +30,7 @@ This project is a MySQL query server based on the MCP framework, supporting real
- Rich MySQL metadata & schema query APIs - Rich MySQL metadata & schema query APIs
- Automatic transaction management & rollback - Automatic transaction management & rollback
- Multi-level SQL risk control & injection protection - Multi-level SQL risk control & injection protection
- **Database Isolation Security**: Prevents cross-database access with 3-level access control
- Automatic and customizable sensitive info masking - Automatic and customizable sensitive info masking
- Flexible environment variable configuration - Flexible environment variable configuration
- Robust logging & error handling - Robust logging & error handling
@ -146,6 +148,8 @@ Default endpoint: http://127.0.0.1:3000/sse
| MAX_SQL_LENGTH | 最大SQL语句长度 / Max SQL length | 5000 | | MAX_SQL_LENGTH | 最大SQL语句长度 / Max SQL length | 5000 |
| BLOCKED_PATTERNS | 阻止的SQL模式(逗号分隔) / Blocked SQL patterns | (空/empty) | | BLOCKED_PATTERNS | 阻止的SQL模式(逗号分隔) / Blocked SQL patterns | (空/empty) |
| ENABLE_QUERY_CHECK | 启用查询安全检查 / Enable query check (true/false) | true | | ENABLE_QUERY_CHECK | 启用查询安全检查 / Enable query check (true/false) | true |
| **ENABLE_DATABASE_ISOLATION** | **启用数据库隔离 / Enable database isolation (true/false)** | **false** |
| **DATABASE_ACCESS_LEVEL** | **数据库访问级别 / Database access level (strict/restricted/permissive)** | **permissive** |
| LOG_LEVEL | 日志级别(DEBUG/INFO/...) / Log level | DEBUG | | LOG_LEVEL | 日志级别(DEBUG/INFO/...) / Log level | DEBUG |
> 注/Note: 部分云MySQL需指定`DB_AUTH_PLUGIN`为`mysql_native_password`。 > 注/Note: 部分云MySQL需指定`DB_AUTH_PLUGIN`为`mysql_native_password`。
@ -185,9 +189,47 @@ When using `caching_sha2_password`, the `cryptography` package is required (alre
pip install cryptography pip install cryptography
``` ```
详细配置指南请参考:[MySQL 8.0 认证插件支持指南](docs/mysql8_authentication.md)
For detailed configuration guide, see: [MySQL 8.0 Authentication Plugin Support Guide](docs/mysql8_authentication.md) ### 数据库隔离安全 / Database Isolation Security
本系统提供强大的数据库隔离功能,防止跨数据库访问,确保数据安全。
This system provides robust database isolation features to prevent cross-database access and ensure data security.
#### 访问级别 / Access Levels
| 级别 / Level | 允许访问 / Allowed Access | 适用场景 / Use Case |
|-------------|---------------------------|-------------------|
| **strict** | 仅指定数据库 / Only specified database | 生产环境 / Production |
| **restricted** | 指定数据库 + 系统库 / Specified + system databases | 开发环境 / Development |
| **permissive** | 所有数据库 / All databases | 测试环境 / Testing |
#### 启用数据库隔离 / Enable Database Isolation
```bash
# Docker 启用严格模式 / Docker with strict mode
docker run -d \
-e MYSQL_DATABASE=your_database \
-e ENABLE_DATABASE_ISOLATION=true \
-e DATABASE_ACCESS_LEVEL=strict \
mangooer/mysql-mcp-server-sse:latest
# 生产环境自动启用 / Auto-enable in production
docker run -d \
-e ENV_TYPE=production \
-e MYSQL_DATABASE=your_database \
mangooer/mysql-mcp-server-sse:latest
```
**安全效果 / Security Effects**
- ✅ 阻止 `SHOW DATABASES` / Blocks `SHOW DATABASES`
- ✅ 阻止 `SELECT * FROM mysql.user` / Blocks `SELECT * FROM mysql.user`
- ✅ 阻止 `SHOW TABLES FROM other_db` / Blocks `SHOW TABLES FROM other_db`
- ✅ 允许当前数据库操作 / Allows current database operations
> 🔒 **重要**:生产环境(`ENV_TYPE=production`)会自动启用数据库隔离,使用 `restricted` 模式。
>
> 🔒 **Important**: Production environment (`ENV_TYPE=production`) automatically enables database isolation with `restricted` mode.
--- ---
@ -222,14 +264,20 @@ For detailed configuration guide, see: [MySQL 8.0 Authentication Plugin Support
- 多级SQL风险等级LOW/MEDIUM/HIGH/CRITICAL - 多级SQL风险等级LOW/MEDIUM/HIGH/CRITICAL
- SQL注入与危险操作拦截 - SQL注入与危险操作拦截
- WHERE子句强制检查 - WHERE子句强制检查
- **数据库隔离安全**三级访问控制strict/restricted/permissive
- **跨数据库访问防护**:阻止未授权的数据库访问
- 敏感信息自动隐藏(支持自定义字段) - 敏感信息自动隐藏(支持自定义字段)
- 生产环境默认只允许低风险操作 - 生产环境默认只允许低风险操作
- **生产环境自动启用数据库隔离**
- Multi-level SQL risk levels (LOW/MEDIUM/HIGH/CRITICAL) - Multi-level SQL risk levels (LOW/MEDIUM/HIGH/CRITICAL)
- SQL injection & dangerous operation interception - SQL injection & dangerous operation interception
- Mandatory WHERE clause check - Mandatory WHERE clause check
- **Database Isolation Security**: 3-level access control (strict/restricted/permissive)
- **Cross-database Access Protection**: Blocks unauthorized database access
- Automatic sensitive info masking (customizable fields) - Automatic sensitive info masking (customizable fields)
- Production allows only low-risk operations by default - Production allows only low-risk operations by default
- **Auto-enable database isolation in production**
--- ---
@ -261,6 +309,18 @@ A: 设置SENSITIVE_INFO_FIELDS如SENSITIVE_INFO_FIELDS=password,token
Q: How to customize sensitive fields? Q: How to customize sensitive fields?
A: Set SENSITIVE_INFO_FIELDS, e.g. SENSITIVE_INFO_FIELDS=password,token A: Set SENSITIVE_INFO_FIELDS, e.g. SENSITIVE_INFO_FIELDS=password,token
### Q: 如何启用数据库隔离?
A: 设置ENABLE_DATABASE_ISOLATION=true和DATABASE_ACCESS_LEVEL=strict或使用ENV_TYPE=production自动启用。
Q: How to enable database isolation?
A: Set ENABLE_DATABASE_ISOLATION=true and DATABASE_ACCESS_LEVEL=strict, or use ENV_TYPE=production for auto-enable.
### Q: 数据库隔离后无法查询系统表?
A: strict模式禁止系统表访问可改为restricted模式或检查是否确实需要系统表访问权限。
Q: Cannot query system tables after enabling database isolation?
A: strict mode blocks system table access. Use restricted mode or verify if system table access is actually needed.
### Q: limit参数报错 ### Q: limit参数报错
A: limit必须为非负整数。 A: limit必须为非负整数。

View File

@ -46,6 +46,16 @@ BLOCKED_PATTERNS=
# 是否启用查询安全检查 # 是否启用查询安全检查
ENABLE_QUERY_CHECK=true ENABLE_QUERY_CHECK=true
# 数据库隔离配置
# 是否启用数据库隔离(防止跨数据库访问)
ENABLE_DATABASE_ISOLATION=false
# 数据库访问级别: strict(严格), restricted(限制), permissive(宽松)
# - strict: 只能访问指定的数据库
# - restricted: 可以访问指定数据库和系统库(information_schema, mysql等)
# - permissive: 可以访问所有数据库(默认)
# 注意:生产环境(ENV_TYPE=production)会自动启用数据库隔离并设为restricted模式
DATABASE_ACCESS_LEVEL=permissive
# 日志配置 # 日志配置
# DEBUG, INFO, WARNING, ERROR, CRITICAL # DEBUG, INFO, WARNING, ERROR, CRITICAL
LOG_LEVEL=DEBUG LOG_LEVEL=DEBUG

View File

@ -109,6 +109,17 @@ class SecurityConfig:
# 查询检查 # 查询检查
ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1') ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1')
# 数据库隔离配置
ENABLE_DATABASE_ISOLATION = os.getenv('ENABLE_DATABASE_ISOLATION', 'false').lower() in ('true', 'yes', '1')
DATABASE_ACCESS_LEVEL = os.getenv('DATABASE_ACCESS_LEVEL', 'permissive').lower()
# 生产环境强制数据库隔离
if ENV_TYPE == EnvironmentType.PRODUCTION and not os.getenv('DATABASE_ACCESS_LEVEL'):
DATABASE_ACCESS_LEVEL = 'restricted' # 生产环境默认使用限制模式
ENABLE_DATABASE_ISOLATION = True
logger = __import__('logging').getLogger("mysql_server")
logger.info("生产环境自动启用数据库隔离,访问级别设为 restricted")
# SQL操作配置 # SQL操作配置
class SQLConfig: class SQLConfig:

View File

@ -0,0 +1,207 @@
"""
数据库范围检查器
用于检测和限制SQL查询中的跨数据库访问
"""
import re
import logging
from typing import Set, Optional, List, Tuple
from enum import Enum
logger = logging.getLogger("mysql_server")
class DatabaseAccessLevel(Enum):
"""数据库访问级别"""
STRICT = "strict" # 严格模式:只能访问指定数据库
RESTRICTED = "restricted" # 限制模式:允许访问指定数据库和系统库
PERMISSIVE = "permissive" # 宽松模式:允许访问所有数据库(默认)
class DatabaseScopeViolation(Exception):
"""数据库范围违规异常"""
pass
class DatabaseScopeChecker:
"""数据库范围检查器"""
# 系统数据库列表
SYSTEM_DATABASES = {
'information_schema',
'mysql',
'performance_schema',
'sys'
}
# 跨数据库查询模式
CROSS_DB_PATTERNS = [
# database.table 格式
r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
# SHOW TABLES FROM database
r'\bSHOW\s+(?:FULL\s+)?TABLES\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\b',
# USE database
r'\bUSE\s+([a-zA-Z_][a-zA-Z0-9_]*)\b',
# SELECT ... FROM database.table
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
# JOIN database.table
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
# INSERT INTO database.table
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
# UPDATE database.table
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
# DELETE FROM database.table
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
]
def __init__(self, allowed_database: Optional[str] = None,
access_level: DatabaseAccessLevel = DatabaseAccessLevel.PERMISSIVE):
"""
初始化数据库范围检查器
Args:
allowed_database: 允许访问的数据库名称
access_level: 访问级别
"""
self.allowed_database = allowed_database
self.access_level = access_level
self.is_enabled = allowed_database is not None and access_level != DatabaseAccessLevel.PERMISSIVE
logger.debug(f"数据库范围检查器初始化: 允许数据库={allowed_database}, 访问级别={access_level.value}, 启用={self.is_enabled}")
def check_query(self, sql_query: str) -> Tuple[bool, List[str]]:
"""
检查SQL查询是否违反数据库范围限制
Args:
sql_query: SQL查询语句
Returns:
(是否允许, 违规详情列表)
"""
if not self.is_enabled:
return True, []
violations = []
# 提取查询中涉及的数据库
referenced_databases = self._extract_databases(sql_query)
for db_name in referenced_databases:
if not self._is_database_allowed(db_name):
violations.append(f"不允许访问数据库: {db_name}")
# 检查特殊查询类型
special_violations = self._check_special_queries(sql_query)
violations.extend(special_violations)
is_allowed = len(violations) == 0
if violations:
logger.warning(f"数据库范围检查失败: {violations}")
return is_allowed, violations
def _extract_databases(self, sql_query: str) -> Set[str]:
"""提取SQL查询中涉及的数据库名称"""
databases = set()
# 标准化SQL转换为大写去除多余空格
normalized_sql = re.sub(r'\s+', ' ', sql_query.upper().strip())
for pattern in self.CROSS_DB_PATTERNS:
matches = re.finditer(pattern, normalized_sql, re.IGNORECASE)
for match in matches:
# 第一个捕获组通常是数据库名
if match.groups():
db_name = match.group(1).lower()
# 过滤掉非数据库名的匹配(如函数名等)
if self._is_valid_database_name(db_name):
databases.add(db_name)
return databases
def _is_valid_database_name(self, name: str) -> bool:
"""检查是否是有效的数据库名称"""
# 数据库名称规则:字母、数字、下划线,不能以数字开头
return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))
def _is_database_allowed(self, db_name: str) -> bool:
"""检查数据库是否被允许访问"""
db_name_lower = db_name.lower()
# 检查是否是允许的主数据库
if self.allowed_database and db_name_lower == self.allowed_database.lower():
return True
# 根据访问级别决定是否允许系统数据库
if self.access_level == DatabaseAccessLevel.RESTRICTED:
if db_name_lower in self.SYSTEM_DATABASES:
return True
return False
def _check_special_queries(self, sql_query: str) -> List[str]:
"""检查特殊类型的查询"""
violations = []
normalized_sql = sql_query.upper().strip()
# 检查SHOW DATABASES查询
if re.search(r'\bSHOW\s+DATABASES\b', normalized_sql):
if self.access_level == DatabaseAccessLevel.STRICT:
violations.append("严格模式下不允许执行 SHOW DATABASES")
# 检查USE语句
if re.search(r'\bUSE\s+', normalized_sql):
violations.append("不允许使用 USE 语句切换数据库")
# 检查系统表访问
system_table_patterns = [
r'\bmysql\.user\b',
r'\bmysql\.db\b',
r'\binformation_schema\.',
r'\bperformance_schema\.',
r'\bsys\.'
]
for pattern in system_table_patterns:
if re.search(pattern, normalized_sql, re.IGNORECASE):
if self.access_level == DatabaseAccessLevel.STRICT:
violations.append(f"严格模式下不允许访问系统表")
break
return violations
def get_allowed_databases(self) -> Set[str]:
"""获取允许访问的数据库列表"""
allowed = set()
if self.allowed_database:
allowed.add(self.allowed_database.lower())
if self.access_level == DatabaseAccessLevel.RESTRICTED:
allowed.update(self.SYSTEM_DATABASES)
return allowed
def is_cross_database_query(self, sql_query: str) -> bool:
"""检查是否是跨数据库查询"""
referenced_dbs = self._extract_databases(sql_query)
return len(referenced_dbs) > 0
# 便捷函数
def create_database_checker(allowed_database: Optional[str] = None,
access_level: str = "permissive") -> DatabaseScopeChecker:
"""
创建数据库范围检查器的便捷函数
Args:
allowed_database: 允许访问的数据库名称
access_level: 访问级别字符串 (strict/restricted/permissive)
Returns:
DatabaseScopeChecker实例
"""
try:
level = DatabaseAccessLevel(access_level.lower())
except ValueError:
logger.warning(f"无效的访问级别: {access_level},使用默认的 permissive")
level = DatabaseAccessLevel.PERMISSIVE
return DatabaseScopeChecker(allowed_database, level)

View File

@ -1,9 +1,10 @@
import logging import logging
from typing import List, Dict from typing import List, Dict
from ..config import SecurityConfig, SQLConfig from ..config import SecurityConfig, SQLConfig, DatabaseConfig
from .sql_analyzer import SQLOperationType, SQLRiskLevel from .sql_analyzer import SQLOperationType, SQLRiskLevel
from .sql_parser import SQLParser from .sql_parser import SQLParser
from .database_scope_checker import DatabaseScopeChecker, DatabaseScopeViolation, create_database_checker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,15 @@ class SQLInterceptor:
self.analyzer = analyzer self.analyzer = analyzer
# 设置最大SQL长度限制 # 设置最大SQL长度限制
self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH
# 初始化数据库范围检查器
self.database_checker = None
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
self.database_checker = create_database_checker(
allowed_database=DatabaseConfig.DATABASE,
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
)
logger.info(f"数据库隔离已启用: 允许数据库={DatabaseConfig.DATABASE}, 访问级别={SecurityConfig.DATABASE_ACCESS_LEVEL}")
async def check_operation(self, sql_query: str) -> bool: async def check_operation(self, sql_query: str) -> bool:
""" """
@ -55,6 +65,13 @@ class SQLInterceptor:
if operation not in supported_operations: if operation not in supported_operations:
raise SecurityException(f"不支持的SQL操作: {operation}") raise SecurityException(f"不支持的SQL操作: {operation}")
# 检查数据库范围限制
if self.database_checker:
is_allowed, violations = self.database_checker.check_query(sql_query)
if not is_allowed:
violation_details = "; ".join(violations)
raise SecurityException(f"数据库访问违规: {violation_details}")
# 分析SQL风险 # 分析SQL风险
risk_analysis = self.analyzer.analyze_risk(sql_query) risk_analysis = self.analyzer.analyze_risk(sql_query)

View File

@ -12,8 +12,10 @@ from mcp.server.fastmcp import FastMCP
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
from src.security.sql_analyzer import EnvironmentType from src.security.sql_analyzer import EnvironmentType
from src.security.database_scope_checker import create_database_checker
from src.db.mysql_operations import get_db_connection, execute_query from src.db.mysql_operations import get_db_connection, execute_query
from src.validators import SQLValidators from src.validators import SQLValidators
from src.config import SecurityConfig, DatabaseConfig
logger = logging.getLogger("mysql_server") logger = logging.getLogger("mysql_server")
@ -131,6 +133,14 @@ def register_info_tools(mcp: FastMCP):
""" """
logger.debug("注册MySQL数据库信息查询工具...") logger.debug("注册MySQL数据库信息查询工具...")
# 创建数据库范围检查器
database_checker = None
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
database_checker = create_database_checker(
allowed_database=DatabaseConfig.DATABASE,
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
)
@mcp.tool() @mcp.tool()
@MetadataToolBase.handle_query_error @MetadataToolBase.handle_query_error
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str: async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
@ -159,6 +169,14 @@ def register_info_tools(mcp: FastMCP):
"返回结果的最大数量必须是非负整数" "返回结果的最大数量必须是非负整数"
) )
# 检查数据库隔离限制
if database_checker:
query_to_check = "SHOW DATABASES"
is_allowed, violations = database_checker.check_query(query_to_check)
if not is_allowed:
violation_details = "; ".join(violations)
raise SecurityError(f"数据库隔离限制: {violation_details}")
# 构建基础查询 # 构建基础查询
query = "SHOW DATABASES" query = "SHOW DATABASES"

View File

@ -3,6 +3,8 @@ import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from src.db.mysql_operations import get_db_connection, execute_query from src.db.mysql_operations import get_db_connection, execute_query
from src.security.database_scope_checker import create_database_checker
from src.config import SecurityConfig, DatabaseConfig
import aiomysql import aiomysql
from .metadata_base_tool import MetadataToolBase from .metadata_base_tool import MetadataToolBase
@ -21,6 +23,14 @@ def register_mysql_tool(mcp: FastMCP):
""" """
logger.debug("注册MySQL查询工具...") logger.debug("注册MySQL查询工具...")
# 创建数据库范围检查器
database_checker = None
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
database_checker = create_database_checker(
allowed_database=DatabaseConfig.DATABASE,
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
)
@mcp.tool() @mcp.tool()
@MetadataToolBase.handle_query_error @MetadataToolBase.handle_query_error
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str: async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
@ -36,6 +46,13 @@ def register_mysql_tool(mcp: FastMCP):
""" """
logger.debug(f"执行MySQL查询: {query}, 参数: {params}") logger.debug(f"执行MySQL查询: {query}, 参数: {params}")
# 检查数据库隔离限制
if database_checker:
is_allowed, violations = database_checker.check_query(query)
if not is_allowed:
violation_details = "; ".join(violations)
raise ValueError(f"数据库隔离限制: {violation_details}")
async with get_db_connection() as connection: async with get_db_connection() as connection:
results = await execute_query(connection, query, params) results = await execute_query(connection, query, params)