diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index 936081384..b27e8d132 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -1,6 +1,7 @@ import opendal import logging import pymysql +import re from urllib.parse import quote_plus from common.config_utils import get_base_config @@ -110,7 +111,8 @@ class OpenDALStorage: ) cursor = conn.cursor() max_packet = self._kwargs.get('max_allowed_packet', 4194304) # Default to 4MB if not specified - cursor.execute(SET_MAX_ALLOWED_PACKET_SQL.format(max_packet)) + # Ensure max_packet is a valid integer to prevent SQL injection + cursor.execute(SET_MAX_ALLOWED_PACKET_SQL.format(int(max_packet))) conn.commit() cursor.close() conn.close() @@ -120,6 +122,11 @@ class OpenDALStorage: raise def init_opendal_mysql_table(self): + table_name = self._kwargs['table'] + # Validate table name to prevent SQL injection + if not re.match(r'^[a-zA-Z0-9_]+$', table_name): + raise ValueError(f"Invalid table name: {table_name}") + conn = pymysql.connect( host=self._kwargs['host'], port=int(self._kwargs['port']), @@ -128,8 +135,8 @@ class OpenDALStorage: database=self._kwargs['database'] ) cursor = conn.cursor() - cursor.execute(CREATE_TABLE_SQL.format(self._kwargs['table'])) + cursor.execute(CREATE_TABLE_SQL.format(table_name)) conn.commit() cursor.close() conn.close() - logging.info(f"Table `{self._kwargs['table']}` initialized.") + logging.info(f"Table `{table_name}` initialized.")