diff --git a/api/db/db_models.py b/api/db/db_models.py index c1a5fd5ed..6f2529e18 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -313,9 +313,75 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase): raise +class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase): + def __init__(self, *args, **kwargs): + self.max_retries = kwargs.pop("max_retries", 5) + self.retry_delay = kwargs.pop("retry_delay", 1) + super().__init__(*args, **kwargs) + + def execute_sql(self, sql, params=None, commit=True): + for attempt in range(self.max_retries + 1): + try: + return super().execute_sql(sql, params, commit) + except (OperationalError, InterfaceError) as e: + # PostgreSQL specific error codes + # 57P01: admin_shutdown + # 57P02: crash_shutdown + # 57P03: cannot_connect_now + # 08006: connection_failure + # 08003: connection_does_not_exist + # 08000: connection_exception + error_messages = ['connection', 'server closed', 'connection refused', + 'no connection to the server', 'terminating connection'] + + should_retry = any(msg in str(e).lower() for msg in error_messages) + + if should_retry and attempt < self.max_retries: + logging.warning( + f"PostgreSQL connection issue (attempt {attempt+1}/{self.max_retries}): {e}" + ) + self._handle_connection_loss() + time.sleep(self.retry_delay * (2 ** attempt)) + else: + logging.error(f"PostgreSQL execution failure: {e}") + raise + return None + + def _handle_connection_loss(self): + try: + self.close() + except Exception: + pass + try: + self.connect() + except Exception as e: + logging.error(f"Failed to reconnect to PostgreSQL: {e}") + time.sleep(0.1) + self.connect() + + def begin(self): + for attempt in range(self.max_retries + 1): + try: + return super().begin() + except (OperationalError, InterfaceError) as e: + error_messages = ['connection', 'server closed', 'connection refused', + 'no connection to the server', 'terminating connection'] + + should_retry = any(msg in str(e).lower() for msg in error_messages) + + if should_retry and attempt < self.max_retries: + logging.warning( + f"PostgreSQL connection lost during transaction (attempt {attempt+1}/{self.max_retries})" + ) + self._handle_connection_loss() + time.sleep(self.retry_delay * (2 ** attempt)) + else: + raise + + class PooledDatabase(Enum): MYSQL = RetryingPooledMySQLDatabase - POSTGRES = PooledPostgresqlDatabase + POSTGRES = RetryingPooledPostgresqlDatabase class DatabaseMigrator(Enum):