diff --git a/api/db/db_models.py b/api/db/db_models.py
index cda279f22..c541587c3 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -26,7 +26,7 @@ from functools import wraps
 
 from flask_login import UserMixin
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
-from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
+from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, InterfaceError, Metadata, Model, OperationalError, TextField
 from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
 from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
 
@@ -250,36 +250,57 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
         super().__init__(*args, **kwargs)
 
+    # MySQL client error codes: 2006 (server has gone away), 2013 (lost connection during query).
+    _CONNECTION_LOST_CODES = (2006, 2013)
+
+    @classmethod
+    def _is_connection_lost(cls, exc):
+        """Return True when ``exc`` looks like a dropped connection that is worth retrying."""
+        if isinstance(exc, InterfaceError):
+            return True
+        if exc.args and exc.args[0] in cls._CONNECTION_LOST_CODES:
+            return True
+        # An empty message is treated as a lost connection as well.
+        message = str(exc)
+        return not message or "Lost connection" in message
+
     def execute_sql(self, sql, params=None, commit=True):
-        from peewee import OperationalError
-
         for attempt in range(self.max_retries + 1):
             try:
                 return super().execute_sql(sql, params, commit)
-            except OperationalError as e:
-                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
-                    logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
+            except (OperationalError, InterfaceError) as e:
+                if self._is_connection_lost(e) and attempt < self.max_retries:
+                    logging.warning(f"Database connection issue (attempt {attempt + 1}/{self.max_retries}): {e}")
                     self._handle_connection_loss()
                     time.sleep(self.retry_delay * (2**attempt))
                 else:
                     logging.error(f"DB execution failure: {e}")
                     raise
         return None
 
     def _handle_connection_loss(self):
-        self.close_all()
-        self.connect()
+        # Close the current connection, then reconnect; the connect is attempted
+        # a second time after a short pause if the first attempt fails.
+        try:
+            self.close()
+        except Exception as e:
+            # Best effort: a failed close must not prevent the reconnect below.
+            logging.debug(f"Ignoring error while closing lost connection: {e}")
+        try:
+            self.connect()
+        except Exception as e:
+            logging.error(f"Failed to reconnect: {e}")
+            time.sleep(0.1)
+            self.connect()
 
     def begin(self):
-        from peewee import OperationalError
-
         for attempt in range(self.max_retries + 1):
             try:
                 return super().begin()
-            except OperationalError as e:
-                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
+            except (OperationalError, InterfaceError) as e:
+                if self._is_connection_lost(e) and attempt < self.max_retries:
                     logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
                     self._handle_connection_loss()
                     time.sleep(self.retry_delay * (2**attempt))
                 else:
                     raise
 
@@ -299,7 +320,12 @@ class BaseDataBase:
     def __init__(self):
         database_config = settings.DATABASE.copy()
         db_name = database_config.pop("name")
+        # Retry defaults; values already present in settings.DATABASE take precedence.
+        # NOTE(review): these keys are forwarded to whichever class PooledDatabase selects;
+        # confirm every backend (not only RetryingPooledMySQLDatabase) accepts them.
+        database_config.setdefault("max_retries", 5)
+        database_config.setdefault("retry_delay", 1)
         self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
         logging.info("init database on cluster mode successfully")
 
 
diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py
index 7645b43d4..a5c871426 100644
--- a/api/db/services/common_service.py
+++ b/api/db/services/common_service.py
@@ -14,12 +14,39 @@
 # limitations under the License.
 #
+import logging
 from datetime import datetime
 
 import peewee
+from peewee import InterfaceError, OperationalError
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 
 from api.db.db_models import DB
 from api.utils import current_timestamp, datetime_format, get_uuid
 
 
+def _log_db_retry(retry_state):
+    """Log the failed attempt that is about to be retried."""
+    logging.warning(
+        "DB operation failed (attempt %d), retrying: %s",
+        retry_state.attempt_number,
+        retry_state.outcome.exception(),
+    )
+
+
+def retry_db_operation(func):
+    """Retry ``func`` on peewee ``InterfaceError`` / ``OperationalError``.
+
+    Makes at most 3 attempts with exponential backoff (1-5 seconds) between
+    them, then re-raises the last error unchanged.
+    """
+    return retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=5),
+        retry=retry_if_exception_type((InterfaceError, OperationalError)),
+        before_sleep=_log_db_retry,
+        reraise=True,
+    )(func)
+
+
 class CommonService:
     """Base service class that provides common database operations.
@@ -202,6 +229,7 @@ class CommonService:
 
     @classmethod
     @DB.connection_context()
+    @retry_db_operation
     def update_by_id(cls, pid, data):
         # Update a single record by ID
         # Args: