mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-08 09:42:27 +08:00
<feat> 添加数据库隔离功能,支持跨数据库访问限制,更新配置和文档说明,增强安全性
This commit is contained in:
64
README.md
64
README.md
@ -18,6 +18,7 @@ This project is a MySQL query server based on the MCP framework, supporting real
|
|||||||
- 丰富的MySQL元数据与结构查询API
|
- 丰富的MySQL元数据与结构查询API
|
||||||
- 自动事务管理与回滚
|
- 自动事务管理与回滚
|
||||||
- 多级SQL风险控制与注入防护
|
- 多级SQL风险控制与注入防护
|
||||||
|
- **数据库隔离安全**:防止跨数据库访问,支持三级访问控制
|
||||||
- 敏感信息自动隐藏与自定义
|
- 敏感信息自动隐藏与自定义
|
||||||
- 灵活的环境变量配置
|
- 灵活的环境变量配置
|
||||||
- 完善的日志与错误处理
|
- 完善的日志与错误处理
|
||||||
@ -29,6 +30,7 @@ This project is a MySQL query server based on the MCP framework, supporting real
|
|||||||
- Rich MySQL metadata & schema query APIs
|
- Rich MySQL metadata & schema query APIs
|
||||||
- Automatic transaction management & rollback
|
- Automatic transaction management & rollback
|
||||||
- Multi-level SQL risk control & injection protection
|
- Multi-level SQL risk control & injection protection
|
||||||
|
- **Database Isolation Security**: Prevents cross-database access with 3-level access control
|
||||||
- Automatic and customizable sensitive info masking
|
- Automatic and customizable sensitive info masking
|
||||||
- Flexible environment variable configuration
|
- Flexible environment variable configuration
|
||||||
- Robust logging & error handling
|
- Robust logging & error handling
|
||||||
@ -146,6 +148,8 @@ Default endpoint: http://127.0.0.1:3000/sse
|
|||||||
| MAX_SQL_LENGTH | 最大SQL语句长度 / Max SQL length | 5000 |
|
| MAX_SQL_LENGTH | 最大SQL语句长度 / Max SQL length | 5000 |
|
||||||
| BLOCKED_PATTERNS | 阻止的SQL模式(逗号分隔) / Blocked SQL patterns | (空/empty) |
|
| BLOCKED_PATTERNS | 阻止的SQL模式(逗号分隔) / Blocked SQL patterns | (空/empty) |
|
||||||
| ENABLE_QUERY_CHECK | 启用查询安全检查 / Enable query check (true/false) | true |
|
| ENABLE_QUERY_CHECK | 启用查询安全检查 / Enable query check (true/false) | true |
|
||||||
|
| **ENABLE_DATABASE_ISOLATION** | **启用数据库隔离 / Enable database isolation (true/false)** | **false** |
|
||||||
|
| **DATABASE_ACCESS_LEVEL** | **数据库访问级别 / Database access level (strict/restricted/permissive)** | **permissive** |
|
||||||
| LOG_LEVEL | 日志级别(DEBUG/INFO/...) / Log level | DEBUG |
|
| LOG_LEVEL | 日志级别(DEBUG/INFO/...) / Log level | DEBUG |
|
||||||
|
|
||||||
> 注/Note: 部分云MySQL需指定`DB_AUTH_PLUGIN`为`mysql_native_password`。
|
> 注/Note: 部分云MySQL需指定`DB_AUTH_PLUGIN`为`mysql_native_password`。
|
||||||
@ -185,9 +189,47 @@ When using `caching_sha2_password`, the `cryptography` package is required (alre
|
|||||||
pip install cryptography
|
pip install cryptography
|
||||||
```
|
```
|
||||||
|
|
||||||
详细配置指南请参考:[MySQL 8.0 认证插件支持指南](docs/mysql8_authentication.md)
|
|
||||||
|
|
||||||
For detailed configuration guide, see: [MySQL 8.0 Authentication Plugin Support Guide](docs/mysql8_authentication.md)
|
### 数据库隔离安全 / Database Isolation Security
|
||||||
|
|
||||||
|
本系统提供强大的数据库隔离功能,防止跨数据库访问,确保数据安全。
|
||||||
|
|
||||||
|
This system provides robust database isolation features to prevent cross-database access and ensure data security.
|
||||||
|
|
||||||
|
#### 访问级别 / Access Levels
|
||||||
|
|
||||||
|
| 级别 / Level | 允许访问 / Allowed Access | 适用场景 / Use Case |
|
||||||
|
|-------------|---------------------------|-------------------|
|
||||||
|
| **strict** | 仅指定数据库 / Only specified database | 生产环境 / Production |
|
||||||
|
| **restricted** | 指定数据库 + 系统库 / Specified + system databases | 开发环境 / Development |
|
||||||
|
| **permissive** | 所有数据库 / All databases | 测试环境 / Testing |
|
||||||
|
|
||||||
|
#### 启用数据库隔离 / Enable Database Isolation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Docker 启用严格模式 / Docker with strict mode
|
||||||
|
docker run -d \
|
||||||
|
-e MYSQL_DATABASE=your_database \
|
||||||
|
-e ENABLE_DATABASE_ISOLATION=true \
|
||||||
|
-e DATABASE_ACCESS_LEVEL=strict \
|
||||||
|
mangooer/mysql-mcp-server-sse:latest
|
||||||
|
|
||||||
|
# 生产环境自动启用 / Auto-enable in production
|
||||||
|
docker run -d \
|
||||||
|
-e ENV_TYPE=production \
|
||||||
|
-e MYSQL_DATABASE=your_database \
|
||||||
|
mangooer/mysql-mcp-server-sse:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
**安全效果 / Security Effects**:
|
||||||
|
- ✅ 阻止 `SHOW DATABASES` / Blocks `SHOW DATABASES`
|
||||||
|
- ✅ 阻止 `SELECT * FROM mysql.user` / Blocks `SELECT * FROM mysql.user`
|
||||||
|
- ✅ 阻止 `SHOW TABLES FROM other_db` / Blocks `SHOW TABLES FROM other_db`
|
||||||
|
- ✅ 允许当前数据库操作 / Allows current database operations
|
||||||
|
|
||||||
|
> 🔒 **重要**:生产环境(`ENV_TYPE=production`)会自动启用数据库隔离,使用 `restricted` 模式。
|
||||||
|
>
|
||||||
|
> 🔒 **Important**: Production environment (`ENV_TYPE=production`) automatically enables database isolation with `restricted` mode.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -222,14 +264,20 @@ For detailed configuration guide, see: [MySQL 8.0 Authentication Plugin Support
|
|||||||
- 多级SQL风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
- 多级SQL风险等级(LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
- SQL注入与危险操作拦截
|
- SQL注入与危险操作拦截
|
||||||
- WHERE子句强制检查
|
- WHERE子句强制检查
|
||||||
|
- **数据库隔离安全**:三级访问控制(strict/restricted/permissive)
|
||||||
|
- **跨数据库访问防护**:阻止未授权的数据库访问
|
||||||
- 敏感信息自动隐藏(支持自定义字段)
|
- 敏感信息自动隐藏(支持自定义字段)
|
||||||
- 生产环境默认只允许低风险操作
|
- 生产环境默认只允许低风险操作
|
||||||
|
- **生产环境自动启用数据库隔离**
|
||||||
|
|
||||||
- Multi-level SQL risk levels (LOW/MEDIUM/HIGH/CRITICAL)
|
- Multi-level SQL risk levels (LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
- SQL injection & dangerous operation interception
|
- SQL injection & dangerous operation interception
|
||||||
- Mandatory WHERE clause check
|
- Mandatory WHERE clause check
|
||||||
|
- **Database Isolation Security**: 3-level access control (strict/restricted/permissive)
|
||||||
|
- **Cross-database Access Protection**: Blocks unauthorized database access
|
||||||
- Automatic sensitive info masking (customizable fields)
|
- Automatic sensitive info masking (customizable fields)
|
||||||
- Production allows only low-risk operations by default
|
- Production allows only low-risk operations by default
|
||||||
|
- **Auto-enable database isolation in production**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -261,6 +309,18 @@ A: 设置SENSITIVE_INFO_FIELDS,如SENSITIVE_INFO_FIELDS=password,token
|
|||||||
Q: How to customize sensitive fields?
|
Q: How to customize sensitive fields?
|
||||||
A: Set SENSITIVE_INFO_FIELDS, e.g. SENSITIVE_INFO_FIELDS=password,token
|
A: Set SENSITIVE_INFO_FIELDS, e.g. SENSITIVE_INFO_FIELDS=password,token
|
||||||
|
|
||||||
|
### Q: 如何启用数据库隔离?
|
||||||
|
A: 设置ENABLE_DATABASE_ISOLATION=true和DATABASE_ACCESS_LEVEL=strict,或使用ENV_TYPE=production自动启用。
|
||||||
|
|
||||||
|
Q: How to enable database isolation?
|
||||||
|
A: Set ENABLE_DATABASE_ISOLATION=true and DATABASE_ACCESS_LEVEL=strict, or use ENV_TYPE=production for auto-enable.
|
||||||
|
|
||||||
|
### Q: 数据库隔离后无法查询系统表?
|
||||||
|
A: strict模式禁止系统表访问,可改为restricted模式,或检查是否确实需要系统表访问权限。
|
||||||
|
|
||||||
|
Q: Cannot query system tables after enabling database isolation?
|
||||||
|
A: strict mode blocks system table access. Use restricted mode or verify if system table access is actually needed.
|
||||||
|
|
||||||
### Q: limit参数报错?
|
### Q: limit参数报错?
|
||||||
A: limit必须为非负整数。
|
A: limit必须为非负整数。
|
||||||
|
|
||||||
|
|||||||
10
example.env
10
example.env
@ -46,6 +46,16 @@ BLOCKED_PATTERNS=
|
|||||||
# 是否启用查询安全检查
|
# 是否启用查询安全检查
|
||||||
ENABLE_QUERY_CHECK=true
|
ENABLE_QUERY_CHECK=true
|
||||||
|
|
||||||
|
# 数据库隔离配置
|
||||||
|
# 是否启用数据库隔离(防止跨数据库访问)
|
||||||
|
ENABLE_DATABASE_ISOLATION=false
|
||||||
|
# 数据库访问级别: strict(严格), restricted(限制), permissive(宽松)
|
||||||
|
# - strict: 只能访问指定的数据库
|
||||||
|
# - restricted: 可以访问指定数据库和系统库(information_schema, mysql等)
|
||||||
|
# - permissive: 可以访问所有数据库(默认)
|
||||||
|
# 注意:生产环境(ENV_TYPE=production)会自动启用数据库隔离并设为restricted模式
|
||||||
|
DATABASE_ACCESS_LEVEL=permissive
|
||||||
|
|
||||||
# 日志配置
|
# 日志配置
|
||||||
# DEBUG, INFO, WARNING, ERROR, CRITICAL
|
# DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||||
LOG_LEVEL=DEBUG
|
LOG_LEVEL=DEBUG
|
||||||
@ -109,6 +109,17 @@ class SecurityConfig:
|
|||||||
|
|
||||||
# 查询检查
|
# 查询检查
|
||||||
ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1')
|
ENABLE_QUERY_CHECK = os.getenv('ENABLE_QUERY_CHECK', 'true').lower() in ('true', 'yes', '1')
|
||||||
|
|
||||||
|
# 数据库隔离配置
|
||||||
|
ENABLE_DATABASE_ISOLATION = os.getenv('ENABLE_DATABASE_ISOLATION', 'false').lower() in ('true', 'yes', '1')
|
||||||
|
DATABASE_ACCESS_LEVEL = os.getenv('DATABASE_ACCESS_LEVEL', 'permissive').lower()
|
||||||
|
|
||||||
|
# 生产环境强制数据库隔离
|
||||||
|
if ENV_TYPE == EnvironmentType.PRODUCTION and not os.getenv('DATABASE_ACCESS_LEVEL'):
|
||||||
|
DATABASE_ACCESS_LEVEL = 'restricted' # 生产环境默认使用限制模式
|
||||||
|
ENABLE_DATABASE_ISOLATION = True
|
||||||
|
logger = __import__('logging').getLogger("mysql_server")
|
||||||
|
logger.info("生产环境自动启用数据库隔离,访问级别设为 restricted")
|
||||||
|
|
||||||
# SQL操作配置
|
# SQL操作配置
|
||||||
class SQLConfig:
|
class SQLConfig:
|
||||||
|
|||||||
207
src/security/database_scope_checker.py
Normal file
207
src/security/database_scope_checker.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
数据库范围检查器
|
||||||
|
用于检测和限制SQL查询中的跨数据库访问
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
from typing import Set, Optional, List, Tuple
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
|
class DatabaseAccessLevel(Enum):
|
||||||
|
"""数据库访问级别"""
|
||||||
|
STRICT = "strict" # 严格模式:只能访问指定数据库
|
||||||
|
RESTRICTED = "restricted" # 限制模式:允许访问指定数据库和系统库
|
||||||
|
PERMISSIVE = "permissive" # 宽松模式:允许访问所有数据库(默认)
|
||||||
|
|
||||||
|
class DatabaseScopeViolation(Exception):
|
||||||
|
"""数据库范围违规异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DatabaseScopeChecker:
|
||||||
|
"""数据库范围检查器"""
|
||||||
|
|
||||||
|
# 系统数据库列表
|
||||||
|
SYSTEM_DATABASES = {
|
||||||
|
'information_schema',
|
||||||
|
'mysql',
|
||||||
|
'performance_schema',
|
||||||
|
'sys'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 跨数据库查询模式
|
||||||
|
CROSS_DB_PATTERNS = [
|
||||||
|
# database.table 格式
|
||||||
|
r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# SHOW TABLES FROM database
|
||||||
|
r'\bSHOW\s+(?:FULL\s+)?TABLES\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# USE database
|
||||||
|
r'\bUSE\s+([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# SELECT ... FROM database.table
|
||||||
|
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# JOIN database.table
|
||||||
|
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# INSERT INTO database.table
|
||||||
|
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# UPDATE database.table
|
||||||
|
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
# DELETE FROM database.table
|
||||||
|
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b',
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, allowed_database: Optional[str] = None,
|
||||||
|
access_level: DatabaseAccessLevel = DatabaseAccessLevel.PERMISSIVE):
|
||||||
|
"""
|
||||||
|
初始化数据库范围检查器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_database: 允许访问的数据库名称
|
||||||
|
access_level: 访问级别
|
||||||
|
"""
|
||||||
|
self.allowed_database = allowed_database
|
||||||
|
self.access_level = access_level
|
||||||
|
self.is_enabled = allowed_database is not None and access_level != DatabaseAccessLevel.PERMISSIVE
|
||||||
|
|
||||||
|
logger.debug(f"数据库范围检查器初始化: 允许数据库={allowed_database}, 访问级别={access_level.value}, 启用={self.is_enabled}")
|
||||||
|
|
||||||
|
def check_query(self, sql_query: str) -> Tuple[bool, List[str]]:
|
||||||
|
"""
|
||||||
|
检查SQL查询是否违反数据库范围限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql_query: SQL查询语句
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否允许, 违规详情列表)
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return True, []
|
||||||
|
|
||||||
|
violations = []
|
||||||
|
|
||||||
|
# 提取查询中涉及的数据库
|
||||||
|
referenced_databases = self._extract_databases(sql_query)
|
||||||
|
|
||||||
|
for db_name in referenced_databases:
|
||||||
|
if not self._is_database_allowed(db_name):
|
||||||
|
violations.append(f"不允许访问数据库: {db_name}")
|
||||||
|
|
||||||
|
# 检查特殊查询类型
|
||||||
|
special_violations = self._check_special_queries(sql_query)
|
||||||
|
violations.extend(special_violations)
|
||||||
|
|
||||||
|
is_allowed = len(violations) == 0
|
||||||
|
|
||||||
|
if violations:
|
||||||
|
logger.warning(f"数据库范围检查失败: {violations}")
|
||||||
|
|
||||||
|
return is_allowed, violations
|
||||||
|
|
||||||
|
def _extract_databases(self, sql_query: str) -> Set[str]:
|
||||||
|
"""提取SQL查询中涉及的数据库名称"""
|
||||||
|
databases = set()
|
||||||
|
|
||||||
|
# 标准化SQL(转换为大写,去除多余空格)
|
||||||
|
normalized_sql = re.sub(r'\s+', ' ', sql_query.upper().strip())
|
||||||
|
|
||||||
|
for pattern in self.CROSS_DB_PATTERNS:
|
||||||
|
matches = re.finditer(pattern, normalized_sql, re.IGNORECASE)
|
||||||
|
for match in matches:
|
||||||
|
# 第一个捕获组通常是数据库名
|
||||||
|
if match.groups():
|
||||||
|
db_name = match.group(1).lower()
|
||||||
|
# 过滤掉非数据库名的匹配(如函数名等)
|
||||||
|
if self._is_valid_database_name(db_name):
|
||||||
|
databases.add(db_name)
|
||||||
|
|
||||||
|
return databases
|
||||||
|
|
||||||
|
def _is_valid_database_name(self, name: str) -> bool:
|
||||||
|
"""检查是否是有效的数据库名称"""
|
||||||
|
# 数据库名称规则:字母、数字、下划线,不能以数字开头
|
||||||
|
return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))
|
||||||
|
|
||||||
|
def _is_database_allowed(self, db_name: str) -> bool:
|
||||||
|
"""检查数据库是否被允许访问"""
|
||||||
|
db_name_lower = db_name.lower()
|
||||||
|
|
||||||
|
# 检查是否是允许的主数据库
|
||||||
|
if self.allowed_database and db_name_lower == self.allowed_database.lower():
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 根据访问级别决定是否允许系统数据库
|
||||||
|
if self.access_level == DatabaseAccessLevel.RESTRICTED:
|
||||||
|
if db_name_lower in self.SYSTEM_DATABASES:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_special_queries(self, sql_query: str) -> List[str]:
|
||||||
|
"""检查特殊类型的查询"""
|
||||||
|
violations = []
|
||||||
|
normalized_sql = sql_query.upper().strip()
|
||||||
|
|
||||||
|
# 检查SHOW DATABASES查询
|
||||||
|
if re.search(r'\bSHOW\s+DATABASES\b', normalized_sql):
|
||||||
|
if self.access_level == DatabaseAccessLevel.STRICT:
|
||||||
|
violations.append("严格模式下不允许执行 SHOW DATABASES")
|
||||||
|
|
||||||
|
# 检查USE语句
|
||||||
|
if re.search(r'\bUSE\s+', normalized_sql):
|
||||||
|
violations.append("不允许使用 USE 语句切换数据库")
|
||||||
|
|
||||||
|
# 检查系统表访问
|
||||||
|
system_table_patterns = [
|
||||||
|
r'\bmysql\.user\b',
|
||||||
|
r'\bmysql\.db\b',
|
||||||
|
r'\binformation_schema\.',
|
||||||
|
r'\bperformance_schema\.',
|
||||||
|
r'\bsys\.'
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in system_table_patterns:
|
||||||
|
if re.search(pattern, normalized_sql, re.IGNORECASE):
|
||||||
|
if self.access_level == DatabaseAccessLevel.STRICT:
|
||||||
|
violations.append(f"严格模式下不允许访问系统表")
|
||||||
|
break
|
||||||
|
|
||||||
|
return violations
|
||||||
|
|
||||||
|
def get_allowed_databases(self) -> Set[str]:
|
||||||
|
"""获取允许访问的数据库列表"""
|
||||||
|
allowed = set()
|
||||||
|
|
||||||
|
if self.allowed_database:
|
||||||
|
allowed.add(self.allowed_database.lower())
|
||||||
|
|
||||||
|
if self.access_level == DatabaseAccessLevel.RESTRICTED:
|
||||||
|
allowed.update(self.SYSTEM_DATABASES)
|
||||||
|
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
def is_cross_database_query(self, sql_query: str) -> bool:
|
||||||
|
"""检查是否是跨数据库查询"""
|
||||||
|
referenced_dbs = self._extract_databases(sql_query)
|
||||||
|
return len(referenced_dbs) > 0
|
||||||
|
|
||||||
|
# 便捷函数
|
||||||
|
def create_database_checker(allowed_database: Optional[str] = None,
|
||||||
|
access_level: str = "permissive") -> DatabaseScopeChecker:
|
||||||
|
"""
|
||||||
|
创建数据库范围检查器的便捷函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_database: 允许访问的数据库名称
|
||||||
|
access_level: 访问级别字符串 (strict/restricted/permissive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseScopeChecker实例
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
level = DatabaseAccessLevel(access_level.lower())
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"无效的访问级别: {access_level},使用默认的 permissive")
|
||||||
|
level = DatabaseAccessLevel.PERMISSIVE
|
||||||
|
|
||||||
|
return DatabaseScopeChecker(allowed_database, level)
|
||||||
@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
from ..config import SecurityConfig, SQLConfig
|
from ..config import SecurityConfig, SQLConfig, DatabaseConfig
|
||||||
from .sql_analyzer import SQLOperationType, SQLRiskLevel
|
from .sql_analyzer import SQLOperationType, SQLRiskLevel
|
||||||
from .sql_parser import SQLParser
|
from .sql_parser import SQLParser
|
||||||
|
from .database_scope_checker import DatabaseScopeChecker, DatabaseScopeViolation, create_database_checker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -18,6 +19,15 @@ class SQLInterceptor:
|
|||||||
self.analyzer = analyzer
|
self.analyzer = analyzer
|
||||||
# 设置最大SQL长度限制
|
# 设置最大SQL长度限制
|
||||||
self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH
|
self.max_sql_length = SecurityConfig.MAX_SQL_LENGTH
|
||||||
|
|
||||||
|
# 初始化数据库范围检查器
|
||||||
|
self.database_checker = None
|
||||||
|
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
|
||||||
|
self.database_checker = create_database_checker(
|
||||||
|
allowed_database=DatabaseConfig.DATABASE,
|
||||||
|
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
|
||||||
|
)
|
||||||
|
logger.info(f"数据库隔离已启用: 允许数据库={DatabaseConfig.DATABASE}, 访问级别={SecurityConfig.DATABASE_ACCESS_LEVEL}")
|
||||||
|
|
||||||
async def check_operation(self, sql_query: str) -> bool:
|
async def check_operation(self, sql_query: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -55,6 +65,13 @@ class SQLInterceptor:
|
|||||||
if operation not in supported_operations:
|
if operation not in supported_operations:
|
||||||
raise SecurityException(f"不支持的SQL操作: {operation}")
|
raise SecurityException(f"不支持的SQL操作: {operation}")
|
||||||
|
|
||||||
|
# 检查数据库范围限制
|
||||||
|
if self.database_checker:
|
||||||
|
is_allowed, violations = self.database_checker.check_query(sql_query)
|
||||||
|
if not is_allowed:
|
||||||
|
violation_details = "; ".join(violations)
|
||||||
|
raise SecurityException(f"数据库访问违规: {violation_details}")
|
||||||
|
|
||||||
# 分析SQL风险
|
# 分析SQL风险
|
||||||
risk_analysis = self.analyzer.analyze_risk(sql_query)
|
risk_analysis = self.analyzer.analyze_risk(sql_query)
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,10 @@ from mcp.server.fastmcp import FastMCP
|
|||||||
|
|
||||||
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
from .metadata_base_tool import MetadataToolBase, ParameterValidationError, QueryExecutionError
|
||||||
from src.security.sql_analyzer import EnvironmentType
|
from src.security.sql_analyzer import EnvironmentType
|
||||||
|
from src.security.database_scope_checker import create_database_checker
|
||||||
from src.db.mysql_operations import get_db_connection, execute_query
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
from src.validators import SQLValidators
|
from src.validators import SQLValidators
|
||||||
|
from src.config import SecurityConfig, DatabaseConfig
|
||||||
|
|
||||||
logger = logging.getLogger("mysql_server")
|
logger = logging.getLogger("mysql_server")
|
||||||
|
|
||||||
@ -131,6 +133,14 @@ def register_info_tools(mcp: FastMCP):
|
|||||||
"""
|
"""
|
||||||
logger.debug("注册MySQL数据库信息查询工具...")
|
logger.debug("注册MySQL数据库信息查询工具...")
|
||||||
|
|
||||||
|
# 创建数据库范围检查器
|
||||||
|
database_checker = None
|
||||||
|
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
|
||||||
|
database_checker = create_database_checker(
|
||||||
|
allowed_database=DatabaseConfig.DATABASE,
|
||||||
|
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
|
||||||
|
)
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
@MetadataToolBase.handle_query_error
|
@MetadataToolBase.handle_query_error
|
||||||
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
|
async def mysql_show_databases(pattern: Optional[str] = None, limit: int = 100, exclude_system: bool = True) -> str:
|
||||||
@ -159,6 +169,14 @@ def register_info_tools(mcp: FastMCP):
|
|||||||
"返回结果的最大数量必须是非负整数"
|
"返回结果的最大数量必须是非负整数"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 检查数据库隔离限制
|
||||||
|
if database_checker:
|
||||||
|
query_to_check = "SHOW DATABASES"
|
||||||
|
is_allowed, violations = database_checker.check_query(query_to_check)
|
||||||
|
if not is_allowed:
|
||||||
|
violation_details = "; ".join(violations)
|
||||||
|
raise SecurityError(f"数据库隔离限制: {violation_details}")
|
||||||
|
|
||||||
# 构建基础查询
|
# 构建基础查询
|
||||||
query = "SHOW DATABASES"
|
query = "SHOW DATABASES"
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import logging
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from src.db.mysql_operations import get_db_connection, execute_query
|
from src.db.mysql_operations import get_db_connection, execute_query
|
||||||
|
from src.security.database_scope_checker import create_database_checker
|
||||||
|
from src.config import SecurityConfig, DatabaseConfig
|
||||||
import aiomysql
|
import aiomysql
|
||||||
|
|
||||||
from .metadata_base_tool import MetadataToolBase
|
from .metadata_base_tool import MetadataToolBase
|
||||||
@ -21,6 +23,14 @@ def register_mysql_tool(mcp: FastMCP):
|
|||||||
"""
|
"""
|
||||||
logger.debug("注册MySQL查询工具...")
|
logger.debug("注册MySQL查询工具...")
|
||||||
|
|
||||||
|
# 创建数据库范围检查器
|
||||||
|
database_checker = None
|
||||||
|
if SecurityConfig.ENABLE_DATABASE_ISOLATION and DatabaseConfig.DATABASE:
|
||||||
|
database_checker = create_database_checker(
|
||||||
|
allowed_database=DatabaseConfig.DATABASE,
|
||||||
|
access_level=SecurityConfig.DATABASE_ACCESS_LEVEL
|
||||||
|
)
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
@MetadataToolBase.handle_query_error
|
@MetadataToolBase.handle_query_error
|
||||||
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||||
@ -36,6 +46,13 @@ def register_mysql_tool(mcp: FastMCP):
|
|||||||
"""
|
"""
|
||||||
logger.debug(f"执行MySQL查询: {query}, 参数: {params}")
|
logger.debug(f"执行MySQL查询: {query}, 参数: {params}")
|
||||||
|
|
||||||
|
# 检查数据库隔离限制
|
||||||
|
if database_checker:
|
||||||
|
is_allowed, violations = database_checker.check_query(query)
|
||||||
|
if not is_allowed:
|
||||||
|
violation_details = "; ".join(violations)
|
||||||
|
raise ValueError(f"数据库隔离限制: {violation_details}")
|
||||||
|
|
||||||
async with get_db_connection() as connection:
|
async with get_db_connection() as connection:
|
||||||
results = await execute_query(connection, query, params)
|
results = await execute_query(connection, query, params)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user