From b0463903f504cd90971311b3c81fe7e935f01cba Mon Sep 17 00:00:00 2001 From: tangyi Date: Thu, 27 Mar 2025 15:44:49 +0800 Subject: [PATCH] =?UTF-8?q?=20=E5=AE=89=E5=85=A8=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 30 ++++- .gitignore | 8 +- README.md | 72 +++++++++++ src/__init__.py | 3 + src/db/__init__.py | 3 + src/db/mysql_operations.py | 47 +++++++- src/security/interceptor.py | 84 +++++++++++++ src/security/query_limiter.py | 68 +++++++++++ src/security/sql_analyzer.py | 221 ++++++++++++++++++++++++++++++++++ src/server.py | 13 +- src/tools/__init__.py | 3 + src/tools/mysql_tool.py | 11 +- 12 files changed, 551 insertions(+), 12 deletions(-) create mode 100644 src/__init__.py create mode 100644 src/db/__init__.py create mode 100644 src/security/interceptor.py create mode 100644 src/security/query_limiter.py create mode 100644 src/security/sql_analyzer.py create mode 100644 src/tools/__init__.py diff --git a/.env.example b/.env.example index 941483d..c432618 100644 --- a/.env.example +++ b/.env.example @@ -7,4 +7,32 @@ MYSQL_DATABASE=test # 服务器配置 PORT=3000 -HOST=127.0.0.1 \ No newline at end of file +HOST=127.0.0.1 + +# 环境配置 +ENV_TYPE=development # 环境类型:development/production + +# SQL风险控制配置 +ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL) +# 如需执行无WHERE条件的DELETE操作,需要将CRITICAL添加到ALLOWED_RISK_LEVELS中 +# 例如:ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH,CRITICAL + +# 禁止的SQL模式(正则表达式,用逗号分隔) +BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE + +# SQL安全检查配置 +ENABLE_QUERY_CHECK=true # 是否启用SQL安全检查 + +# 配置示例说明 +# ------------------- +# 开发环境配置示例: +# ENV_TYPE=development +# ALLOWED_RISK_LEVELS=LOW,MEDIUM,HIGH # 允许除CRITICAL外的所有风险等级 +# BLOCKED_PATTERNS=DROP\s+DATABASE,TRUNCATE\s+TABLE # 仅禁止最危险的操作 +# ENABLE_QUERY_CHECK=true # 启用SQL安全检查 + +# 生产环境配置示例: +# ENV_TYPE=production +# ALLOWED_RISK_LEVELS=LOW # 只允许低风险操作(SELECT) +# BLOCKED_PATTERNS=DROP,TRUNCATE,DELETE,UPDATE # 禁止所有数据修改操作 +# ENABLE_QUERY_CHECK=true # 强制启用SQL安全检查 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5cf2ff6..802c731 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,10 @@ vendor/ .idea .cursorrules .cursor -**/__pycache__/ \ No newline at end of file +**/__pycache__/ +tests/ +.env +.coverage +docs/ +.pytest_cache/ +requirements-dev.txt \ No newline at end of file diff --git a/README.md b/README.md index 8d7752f..4a1e3cf 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,14 @@ - 支持SSE(Server-Sent Events)实时数据传输 - 提供MySQL数据库查询接口 - 完整的日志记录系统 +- 自动事务管理(提交/回滚) - 环境变量配置支持 +- SQL安全检查机制 + - 风险等级控制 + - SQL注入防护 + - 危险操作拦截 + - WHERE子句强制检查 + - 自动返回修改操作影响的行数 ## 系统要求 @@ -36,6 +43,7 @@ pip install -r requirements.txt 在`.env`文件中配置以下参数: +### 基本配置 - `HOST`: 服务器监听地址(默认:127.0.0.1) - `PORT`: 服务器监听端口(默认:3000) - `MYSQL_HOST`: MySQL服务器地址 @@ -44,6 +52,45 @@ pip install -r requirements.txt - `MYSQL_PASSWORD`: MySQL密码 - `MYSQL_DATABASE`: MySQL数据库名 +### SQL安全配置 +- `ENV_TYPE`: 环境类型(development/production) +- `ALLOWED_RISK_LEVELS`: 允许的风险等级(LOW/MEDIUM/HIGH/CRITICAL) +- `BLOCKED_PATTERNS`: 禁止的SQL模式(正则表达式,用逗号分隔) +- `ENABLE_QUERY_CHECK`: 是否启用SQL安全检查(true/false) + +## SQL安全机制 + +### 风险等级控制 +- LOW: 查询操作(SELECT) +- MEDIUM: 基本数据修改(INSERT,有WHERE的UPDATE/DELETE) +- HIGH: 结构变更(CREATE/ALTER)和无WHERE的UPDATE +- CRITICAL: 危险操作(DROP/TRUNCATE)和无WHERE的DELETE操作 + +### 环境变量加载顺序 +项目使用python-dotenv加载环境变量,需要确保在导入其他模块前先加载环境变量,否则可能会导致配置未被正确应用。 + +### 安全检查机制 +- 强制要求UPDATE/DELETE操作包含WHERE子句 +- SQL语句语法检查 +- 危险操作模式检测 +- 自动识别受影响的表 +- 风险等级评估 + +### 环境特性 +- 开发环境:允许较高风险的操作,但仍需遵守基本安全规则 +- 生产环境:默认只允许LOW风险操作,严格限制数据修改 + +### 安全拦截 +- SQL语句长度限制 +- 危险操作模式检测 +- 自动识别受影响的表 +- SQL注入防护 + +### 事务管理 +- 对于修改操作(INSERT/UPDATE/DELETE)会自动提交事务 +- 执行错误时自动回滚事务 +- 返回操作影响的行数 + ## 启动服务器 运行以下命令启动服务器: @@ -61,11 +108,34 @@ python src/server.py ├── src/ # 源代码目录 │ ├── server.py # 主服务器文件 │ ├── db/ # 数据库相关代码 +│ ├── security/ # SQL安全相关代码 +│ │ ├── interceptor.py # SQL拦截器 +│ │ ├── query_limiter.py # SQL安全检查器 +│ │ └── sql_analyzer.py # SQL分析器 │ └── tools/ # 工具类代码 +├── tests/ # 测试代码目录 ├── .env.example # 环境变量示例文件 └── requirements.txt # 项目依赖文件 ``` +## 常见问题解决 + +### DELETE操作未执行成功 +- 检查DELETE操作是否包含WHERE条件 +- 无WHERE条件的DELETE操作被标记为CRITICAL风险级别 +- 确保环境变量ALLOWED_RISK_LEVELS中包含CRITICAL(如果需要执行该操作) +- 检查影响行数返回值,确认操作是否实际影响了数据库 + +### 环境变量未生效 +- 确保在server.py中的load_dotenv()调用发生在导入其他模块之前 +- 重启应用以确保环境变量被正确加载 +- 检查日志中"从环境变量读取到的风险等级设置"的输出 + +### 操作被安全机制拒绝 +- 检查操作的风险级别是否在允许的范围内 +- 如果需要执行高风险操作,相应地调整ALLOWED_RISK_LEVELS +- 对于不带WHERE条件的UPDATE或DELETE,可以添加条件(即使是WHERE 1=1)降低风险级别 + ## 日志系统 服务器包含完整的日志记录系统,可以在控制台和日志文件中查看运行状态和错误信息。日志级别可以在`server.py`中配置。 @@ -75,7 +145,9 @@ python src/server.py 服务器包含完善的错误处理机制: - MySQL连接器导入检查 - 数据库配置验证 +- SQL安全检查 - 运行时错误捕获和记录 +- 事务自动回滚 ## 贡献指南 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..c8837f2 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,3 @@ +""" +MySQL查询服务器包 +""" \ No newline at end of file diff --git a/src/db/__init__.py b/src/db/__init__.py new file mode 100644 index 0000000..e9edceb --- /dev/null +++ b/src/db/__init__.py @@ -0,0 +1,3 @@ +""" +MySQL数据库操作包 +""" \ No newline at end of file diff --git a/src/db/mysql_operations.py b/src/db/mysql_operations.py index a3ea89f..93e6c05 100644 --- a/src/db/mysql_operations.py +++ b/src/db/mysql_operations.py @@ -5,8 +5,17 @@ from mysql.connector import Error from contextlib import contextmanager from typing import Any, Dict, List, Optional +from ..security.sql_analyzer import SQLOperationType +from ..security.query_limiter import QueryLimiter +from ..security.interceptor import SQLInterceptor, SecurityException + logger = logging.getLogger("mysql_server") +# 初始化安全组件 +sql_analyzer = SQLOperationType() +query_limiter = QueryLimiter() +sql_interceptor = SQLInterceptor(sql_analyzer) + def get_db_config(): """动态获取数据库配置""" return { @@ -15,7 +24,8 @@ def get_db_config(): 'password': os.getenv('MYSQL_PASSWORD', ''), 'database': os.getenv('MYSQL_DATABASE', ''), 'port': int(os.getenv('MYSQL_PORT', '3306')), - 'connection_timeout': 5 + 'connection_timeout': 5, + 'auth_plugin': 'mysql_native_password' # 指定认证插件 } @contextmanager @@ -54,7 +64,7 @@ def get_db_connection(): connection.close() logger.debug("数据库连接已关闭") -def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: +async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: """ 在给定的数据库连接上执行查询 @@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non params: 查询参数 (可选) Returns: - 查询结果列表 + 查询结果列表,如果是修改操作则返回影响的行数 + + Raises: + SecurityException: 当操作被安全机制拒绝时 + ValueError: 当查询执行失败时 """ cursor = None try: + # 安全检查 + if not await sql_interceptor.check_operation(query): + raise SecurityException("操作被安全机制拒绝") + cursor = connection.cursor(dictionary=True) # 执行查询 @@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non else: cursor.execute(query) - # 获取结果 + # 获取操作类型 + operation = query.strip().split()[0].upper() + + # 对于修改操作,提交事务并返回影响的行数 + if operation in {'UPDATE', 'DELETE', 'INSERT'}: + affected_rows = cursor.rowcount + # 提交事务,确保更改被保存 + connection.commit() + logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据") + return [{'affected_rows': affected_rows}] + + # 对于查询操作,返回结果集 results = cursor.fetchall() logger.debug(f"查询返回 {len(results)} 条结果") return results + except SecurityException as security_err: + logger.error(f"安全检查失败: {str(security_err)}") + raise except mysql.connector.Error as query_err: + # 如果发生错误,进行回滚 + if operation in {'UPDATE', 'DELETE', 'INSERT'}: + try: + connection.rollback() + logger.debug("事务已回滚") + except: + pass logger.error(f"查询执行失败: {str(query_err)}") raise ValueError(f"查询执行失败: {str(query_err)}") finally: diff --git a/src/security/interceptor.py b/src/security/interceptor.py new file mode 100644 index 0000000..6ae70b0 --- /dev/null +++ b/src/security/interceptor.py @@ -0,0 +1,84 @@ +import logging +import os +from typing import List, Dict + +from .sql_analyzer import SQLOperationType, SQLRiskLevel + +logger = logging.getLogger(__name__) + +class SecurityException(Exception): + """安全相关异常""" + pass + +class SQLInterceptor: + """SQL操作拦截器""" + + def __init__(self, analyzer: SQLOperationType): + self.analyzer = analyzer + # 设置最大SQL长度限制(默认1000个字符) + self.max_sql_length = 1000 + + async def check_operation(self, sql_query: str) -> bool: + """ + 检查SQL操作是否允许执行 + + Args: + sql_query: SQL查询语句 + + Returns: + bool: 是否允许执行 + + Raises: + SecurityException: 当操作被拒绝时抛出 + """ + try: + # 检查SQL是否为空 + if not sql_query or not sql_query.strip(): + raise SecurityException("SQL语句不能为空") + + # 检查SQL长度 + if len(sql_query) > self.max_sql_length: + raise SecurityException(f"SQL语句长度({len(sql_query)})超出限制({self.max_sql_length})") + + # 检查SQL是否有效 + sql_parts = sql_query.strip().split() + if not sql_parts: + raise SecurityException("SQL语句格式无效") + + operation = sql_parts[0].upper() + if operation not in {'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'MERGE'}: + raise SecurityException(f"不支持的SQL操作: {operation}") + + # 分析SQL风险 + risk_analysis = self.analyzer.analyze_risk(sql_query) + + # 检查是否是危险操作 + if risk_analysis['is_dangerous']: + raise SecurityException( + f"检测到危险操作: {risk_analysis['operation']}" + ) + + # 检查操作是否被允许 + if not risk_analysis['is_allowed']: + raise SecurityException( + f"当前操作风险等级({risk_analysis['risk_level'].name})不被允许执行," + f"允许的风险等级: {[level.name for level in self.analyzer.allowed_risk_levels]}" + ) + + # 记录详细日志 + logger.info( + f"SQL操作检查通过 - " + f"操作: {risk_analysis['operation']}, " + f"风险等级: {risk_analysis['risk_level'].name}, " + f"影响表: {', '.join(risk_analysis['affected_tables'])}" + ) + + return True + + except SecurityException as e: + logger.error(str(e)) + raise + except Exception as e: + error_msg = f"安全检查失败: {str(e)}" + logger.error(error_msg) + raise SecurityException(error_msg) \ No newline at end of file diff --git a/src/security/query_limiter.py b/src/security/query_limiter.py new file mode 100644 index 0000000..e8e255b --- /dev/null +++ b/src/security/query_limiter.py @@ -0,0 +1,68 @@ +import os +import logging +from typing import Tuple + +logger = logging.getLogger(__name__) + +class QueryLimiter: + """查询安全检查器""" + + def __init__(self): + # 解析启用状态(默认启用) + enable_check = os.getenv('ENABLE_QUERY_CHECK', 'true') + self.enable_check = str(enable_check).lower() not in {'false', '0', 'no', 'off'} + + def check_query(self, sql_query: str) -> Tuple[bool, str]: + """ + 检查SQL查询是否安全 + + Args: + sql_query: SQL查询语句 + + Returns: + Tuple[bool, str]: (是否允许执行, 错误信息) + """ + if not self.enable_check: + return True, "" + + sql_query = sql_query.strip().upper() + operation_type = self._get_operation_type(sql_query) + + # 检查是否为无 WHERE 子句的更新/删除操作 + if operation_type in {'UPDATE', 'DELETE'} and 'WHERE' not in sql_query: + error_msg = f"{operation_type}操作必须包含WHERE子句" + logger.warning(f"查询被限制: {error_msg}") + return False, error_msg + + return True, "" + + def _get_operation_type(self, sql_query: str) -> str: + """获取SQL操作类型""" + if not sql_query: + return "" + words = sql_query.split() + if not words: + return "" + return words[0].upper() + + def _parse_int_env(self, env_name: str, default: int) -> int: + """解析整数类型的环境变量""" + try: + return int(os.getenv(env_name, str(default))) + except (ValueError, TypeError): + return default + + def update_limits(self, new_limits: dict): + """ + 更新限制阈值 + + Args: + new_limits: 新的限制值字典 + """ + for operation, limit in new_limits.items(): + if operation in self.max_limits: + try: + self.max_limits[operation] = int(limit) + logger.info(f"更新{operation}操作的限制为: {limit}") + except (ValueError, TypeError): + logger.warning(f"无效的限制值: {operation}={limit}") \ No newline at end of file diff --git a/src/security/sql_analyzer.py b/src/security/sql_analyzer.py new file mode 100644 index 0000000..4f99aa8 --- /dev/null +++ b/src/security/sql_analyzer.py @@ -0,0 +1,221 @@ +import re +import os +from enum import IntEnum, Enum +import logging +from typing import Set, List + +logger = logging.getLogger(__name__) + +class SQLRiskLevel(IntEnum): + """SQL操作风险等级""" + LOW = 1 # 查询操作(SELECT) + MEDIUM = 2 # 基本数据修改(INSERT,有WHERE的UPDATE/DELETE) + HIGH = 3 # 结构变更(CREATE/ALTER)和无WHERE的数据修改 + CRITICAL = 4 # 危险操作(DROP/TRUNCATE等) + +class EnvironmentType(Enum): + """环境类型""" + DEVELOPMENT = 'development' + PRODUCTION = 'production' + +class SQLOperationType: + """SQL操作类型分析器""" + + def __init__(self): + # 环境类型处理 + env_type_str = os.getenv('ENV_TYPE', 'development').lower() + try: + self.env_type = EnvironmentType(env_type_str) + except ValueError: + logger.warning(f"无效的环境类型: {env_type_str},使用默认值: development") + self.env_type = EnvironmentType.DEVELOPMENT + + # 基础操作集合 + self.ddl_operations = { + 'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'RENAME' + } + self.dml_operations = { + 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE' + } + + # 风险等级配置 + self.allowed_risk_levels = self._parse_risk_levels() + self.blocked_patterns = self._parse_blocked_patterns('BLOCKED_PATTERNS') + + # 生产环境特殊处理:如果没有明确配置风险等级,则只允许LOW风险操作 + if self.env_type == EnvironmentType.PRODUCTION and not os.getenv('ALLOWED_RISK_LEVELS'): + self.allowed_risk_levels = {SQLRiskLevel.LOW} + + logger.info(f"SQL分析器初始化 - 环境: {self.env_type.value}") + logger.info(f"允许的风险等级: {[level.name for level in self.allowed_risk_levels]}") + + def _parse_risk_levels(self) -> Set[SQLRiskLevel]: + """解析允许的风险等级""" + allowed_levels_str = os.getenv('ALLOWED_RISK_LEVELS', 'LOW,MEDIUM') + allowed_levels = set() + + logger.info(f"从环境变量读取到的风险等级设置: '{allowed_levels_str}'") + + for level_str in allowed_levels_str.upper().split(','): + level_str = level_str.strip() + try: + allowed_levels.add(SQLRiskLevel[level_str]) + except KeyError: + logger.warning(f"未知的风险等级配置: {level_str}") + + return allowed_levels + + def _parse_blocked_patterns(self, env_var: str) -> List[str]: + """解析禁止的操作模式""" + patterns = os.getenv(env_var, '').split(',') + return [p.strip() for p in patterns if p.strip()] + + def analyze_risk(self, sql_query: str) -> dict: + """ + 分析SQL查询的风险级别和影响范围 + + Args: + sql_query: SQL查询语句 + + Returns: + dict: 包含风险分析结果的字典 + """ + sql_query = sql_query.strip() + + # 处理空SQL + if not sql_query: + return { + 'operation': '', + 'operation_type': 'UNKNOWN', + 'is_dangerous': True, + 'affected_tables': [], + 'estimated_impact': { + 'operation': '', + 'estimated_rows': 0, + 'needs_where': False, + 'has_where': False + }, + 'risk_level': SQLRiskLevel.HIGH, + 'is_allowed': False + } + + operation = sql_query.split()[0].upper() + + # 基本风险分析 + risk_analysis = { + 'operation': operation, + 'operation_type': 'DDL' if operation in self.ddl_operations else 'DML', + 'is_dangerous': self._check_dangerous_patterns(sql_query), + 'affected_tables': self._get_affected_tables(sql_query), + 'estimated_impact': self._estimate_impact(sql_query) + } + + # 计算风险等级 + risk_level = self._calculate_risk_level(sql_query, operation, risk_analysis['is_dangerous']) + risk_analysis['risk_level'] = risk_level + risk_analysis['is_allowed'] = risk_level in self.allowed_risk_levels + + return risk_analysis + + def _calculate_risk_level(self, sql_query: str, operation: str, is_dangerous: bool) -> SQLRiskLevel: + """ + 计算操作风险等级 + + 规则: + 1. 危险操作(匹配危险模式)=> CRITICAL + 2. DDL操作: + - CREATE/ALTER => HIGH + - DROP/TRUNCATE => CRITICAL + 3. DML操作: + - SELECT => LOW + - INSERT => MEDIUM + - UPDATE/DELETE(有WHERE)=> MEDIUM + - UPDATE(无WHERE)=> HIGH + - DELETE(无WHERE)=> CRITICAL + """ + # 危险操作 + if is_dangerous: + return SQLRiskLevel.CRITICAL + + # 生产环境中非SELECT操作 + if self.env_type == EnvironmentType.PRODUCTION and operation != 'SELECT': + return SQLRiskLevel.CRITICAL + + # DDL操作 + if operation in self.ddl_operations: + if operation in {'DROP', 'TRUNCATE'}: + return SQLRiskLevel.CRITICAL + return SQLRiskLevel.HIGH + + # DML操作 + if operation == 'SELECT': + return SQLRiskLevel.LOW + elif operation == 'INSERT': + return SQLRiskLevel.MEDIUM + elif operation == 'UPDATE': + return SQLRiskLevel.HIGH if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM + elif operation == 'DELETE': + # 无WHERE条件的DELETE操作视为CRITICAL风险 + return SQLRiskLevel.CRITICAL if 'WHERE' not in sql_query.upper() else SQLRiskLevel.MEDIUM + + return SQLRiskLevel.HIGH + + def _check_dangerous_patterns(self, sql_query: str) -> bool: + """检查是否匹配危险操作模式""" + sql_upper = sql_query.upper() + + # 生产环境额外的安全检查 + if self.env_type == EnvironmentType.PRODUCTION: + # 生产环境中禁止所有非SELECT操作 + if sql_upper.split()[0] != 'SELECT': + return True + + for pattern in self.blocked_patterns: + if re.search(pattern, sql_upper, re.IGNORECASE): + return True + + return False + + def _get_affected_tables(self, sql_query: str) -> list: + """获取受影响的表名列表""" + words = sql_query.upper().split() + tables = [] + + for i, word in enumerate(words): + if word in {'FROM', 'JOIN', 'UPDATE', 'INTO', 'TABLE'}: + if i + 1 < len(words): + table = words[i + 1].strip('`;') + if table not in {'SELECT', 'WHERE', 'SET'}: + tables.append(table) + + return list(set(tables)) + + def _estimate_impact(self, sql_query: str) -> dict: + """ + 估算查询影响范围 + + Returns: + dict: 包含预估影响的字典 + """ + operation = sql_query.split()[0].upper() + + impact = { + 'operation': operation, + 'estimated_rows': 0, + 'needs_where': operation in {'UPDATE', 'DELETE'}, + 'has_where': 'WHERE' in sql_query.upper() + } + + # 根据环境类型调整估算 + if self.env_type == EnvironmentType.PRODUCTION: + if operation == 'SELECT': + impact['estimated_rows'] = 100 + else: + impact['estimated_rows'] = float('inf') # 生产环境中非SELECT操作视为影响无限行 + else: + if operation == 'SELECT': + impact['estimated_rows'] = 100 + elif operation in {'UPDATE', 'DELETE'}: + impact['estimated_rows'] = 1000 if impact['has_where'] else float('inf') + + return impact \ No newline at end of file diff --git a/src/server.py b/src/server.py index 13df382..d521b27 100644 --- a/src/server.py +++ b/src/server.py @@ -2,6 +2,11 @@ from mcp.server.fastmcp import FastMCP import os import logging from dotenv import load_dotenv + +# 加载环境变量 - 移到最前面确保所有模块导入前环境变量已加载 +load_dotenv() + +# 导入自定义模块 - 确保在load_dotenv之后导入 from src.tools.mysql_tool import register_mysql_tool # 配置日志 @@ -11,6 +16,10 @@ logging.basicConfig( ) logger = logging.getLogger("mysql_server") +# 记录环境变量加载情况 +logger.debug("已加载环境变量") +logger.debug(f"当前允许的风险等级: {os.getenv('ALLOWED_RISK_LEVELS', '未设置')}") + # 尝试导入MySQL连接器 try: import mysql.connector @@ -21,10 +30,6 @@ except ImportError as e: logger.critical("请确保已安装mysql-connector-python包: pip install mysql-connector-python") mysql_available = False -# 加载环境变量 -load_dotenv() -logger.debug("已加载环境变量") - # 从环境变量获取服务器配置 host = os.getenv('HOST', '127.0.0.1') port = int(os.getenv('PORT', '3000')) diff --git a/src/tools/__init__.py b/src/tools/__init__.py new file mode 100644 index 0000000..ea4a702 --- /dev/null +++ b/src/tools/__init__.py @@ -0,0 +1,3 @@ +""" +MySQL工具包 +""" \ No newline at end of file diff --git a/src/tools/mysql_tool.py b/src/tools/mysql_tool.py index 2d9310e..04d3384 100644 --- a/src/tools/mysql_tool.py +++ b/src/tools/mysql_tool.py @@ -24,7 +24,7 @@ def register_mysql_tool(mcp: FastMCP): logger.debug("注册MySQL查询工具...") @mcp.tool() - def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str: + async def mysql_query(query: str, params: Optional[Dict[str, Any]] = None) -> str: """ 执行MySQL查询并返回结果 @@ -39,7 +39,14 @@ def register_mysql_tool(mcp: FastMCP): try: with get_db_connection() as connection: - results = execute_query(connection, query, params) + results = await execute_query(connection, query, params) + + # 检查是否是修改操作返回的影响行数 + operation = query.strip().split()[0].upper() + if operation in {'UPDATE', 'DELETE', 'INSERT'} and results and 'affected_rows' in results[0]: + affected_rows = results[0]['affected_rows'] + logger.info(f"{operation}操作影响了{affected_rows}行数据") + return json.dumps(results, default=str) except Exception as e: