mirror of
https://github.com/mangooer/mysql-mcp-server-sse.git
synced 2025-12-20 04:38:56 +08:00
<feat> 安全限制
This commit is contained in:
3
src/db/__init__.py
Normal file
3
src/db/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
MySQL数据库操作包
|
||||
"""
|
||||
@ -5,8 +5,17 @@ from mysql.connector import Error
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..security.sql_analyzer import SQLOperationType
|
||||
from ..security.query_limiter import QueryLimiter
|
||||
from ..security.interceptor import SQLInterceptor, SecurityException
|
||||
|
||||
logger = logging.getLogger("mysql_server")
|
||||
|
||||
# 初始化安全组件
|
||||
sql_analyzer = SQLOperationType()
|
||||
query_limiter = QueryLimiter()
|
||||
sql_interceptor = SQLInterceptor(sql_analyzer)
|
||||
|
||||
def get_db_config():
|
||||
"""动态获取数据库配置"""
|
||||
return {
|
||||
@ -15,7 +24,8 @@ def get_db_config():
|
||||
'password': os.getenv('MYSQL_PASSWORD', ''),
|
||||
'database': os.getenv('MYSQL_DATABASE', ''),
|
||||
'port': int(os.getenv('MYSQL_PORT', '3306')),
|
||||
'connection_timeout': 5
|
||||
'connection_timeout': 5,
|
||||
'auth_plugin': 'mysql_native_password' # 指定认证插件
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
@ -54,7 +64,7 @@ def get_db_connection():
|
||||
connection.close()
|
||||
logger.debug("数据库连接已关闭")
|
||||
|
||||
def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
async def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
在给定的数据库连接上执行查询
|
||||
|
||||
@ -64,10 +74,18 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
||||
params: 查询参数 (可选)
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
查询结果列表,如果是修改操作则返回影响的行数
|
||||
|
||||
Raises:
|
||||
SecurityException: 当操作被安全机制拒绝时
|
||||
ValueError: 当查询执行失败时
|
||||
"""
|
||||
cursor = None
|
||||
try:
|
||||
# 安全检查
|
||||
if not await sql_interceptor.check_operation(query):
|
||||
raise SecurityException("操作被安全机制拒绝")
|
||||
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 执行查询
|
||||
@ -76,12 +94,33 @@ def execute_query(connection, query: str, params: Optional[Dict[str, Any]] = Non
|
||||
else:
|
||||
cursor.execute(query)
|
||||
|
||||
# 获取结果
|
||||
# 获取操作类型
|
||||
operation = query.strip().split()[0].upper()
|
||||
|
||||
# 对于修改操作,提交事务并返回影响的行数
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
affected_rows = cursor.rowcount
|
||||
# 提交事务,确保更改被保存
|
||||
connection.commit()
|
||||
logger.debug(f"修改操作 {operation} 影响了 {affected_rows} 行数据")
|
||||
return [{'affected_rows': affected_rows}]
|
||||
|
||||
# 对于查询操作,返回结果集
|
||||
results = cursor.fetchall()
|
||||
logger.debug(f"查询返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
except SecurityException as security_err:
|
||||
logger.error(f"安全检查失败: {str(security_err)}")
|
||||
raise
|
||||
except mysql.connector.Error as query_err:
|
||||
# 如果发生错误,进行回滚
|
||||
if operation in {'UPDATE', 'DELETE', 'INSERT'}:
|
||||
try:
|
||||
connection.rollback()
|
||||
logger.debug("事务已回滚")
|
||||
except:
|
||||
pass
|
||||
logger.error(f"查询执行失败: {str(query_err)}")
|
||||
raise ValueError(f"查询执行失败: {str(query_err)}")
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user