mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: Add RetryingPooledPostgresqlDatabase to handle max_retries param (#10524)
## What problem does this PR solve?

Fixes the PostgreSQL connection error that prevents RAGFlow from starting:

    peewee.ProgrammingError: invalid dsn: invalid connection option "max_retries"

## Problem Analysis

The `BaseDataBase` class in `api/db/db_models.py` adds `max_retries` and `retry_delay` to the database configuration dict before passing it to the database connection constructor.

- **MySQL**: Has `RetryingPooledMySQLDatabase` class that properly extracts these custom parameters using `kwargs.pop()` before calling the parent constructor
- **PostgreSQL**: Was using the base `PooledPostgresqlDatabase` class which passes all parameters directly to `psycopg2.connect()`, which doesn't recognize `max_retries` as a valid connection option

## Solution

Created `RetryingPooledPostgresqlDatabase` class that:

- Extracts `max_retries` and `retry_delay` parameters before initialization
- Implements retry logic with exponential backoff for connection failures
- Handles PostgreSQL-specific connection errors (connection refused, server closed, etc.)
- Mirrors the existing `RetryingPooledMySQLDatabase` implementation

Updated the `PooledDatabase` enum to use the new retrying class for PostgreSQL.

## Benefits

✅ Prevents invalid connection parameters from being passed to psycopg2
✅ Adds automatic retry logic for PostgreSQL connection failures
✅ Provides better error logging for PostgreSQL-specific issues
✅ Maintains consistency between MySQL and PostgreSQL database handling

## Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

## Testing

Tested with PostgreSQL database configuration and verified:

- Server starts without the "invalid dsn" error
- Database connections are established successfully
- Retry logic works correctly on connection failures

Co-authored-by: Andrea Bugeja <andrea.bugeja@gig.com>
This commit is contained in:
@@ -313,9 +313,75 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
    """Pooled PostgreSQL database that retries on connection-level errors.

    The constructor strips the custom ``max_retries`` and ``retry_delay``
    options from the keyword arguments before delegating to the parent
    constructor, so they are never forwarded as connection options.

    ``execute_sql`` and ``begin`` retry when the parent raises an
    ``OperationalError`` or ``InterfaceError`` whose message looks like a
    connection problem, reconnecting and sleeping with exponential backoff
    (``retry_delay * 2 ** attempt`` seconds) between attempts.
    """

    # Lower-case substrings that mark an error message as a connection
    # problem worth retrying. Matching is done on the message text only; the
    # PostgreSQL SQLSTATE codes this is meant to cover are:
    #   57P01: admin_shutdown
    #   57P02: crash_shutdown
    #   57P03: cannot_connect_now
    #   08006: connection_failure
    #   08003: connection_does_not_exist
    #   08000: connection_exception
    _CONNECTION_ERROR_MARKERS = (
        'connection',
        'server closed',
        'connection refused',
        'no connection to the server',
        'terminating connection',
    )

    def __init__(self, *args, **kwargs):
        """Pop the retry options and pass everything else to the parent.

        Keyword Args:
            max_retries: Number of retries after the first failed attempt
                (default 5).
            retry_delay: Base delay in seconds for the exponential backoff
                (default 1).
        """
        self.max_retries = kwargs.pop("max_retries", 5)
        self.retry_delay = kwargs.pop("retry_delay", 1)
        super().__init__(*args, **kwargs)

    def _is_connection_error(self, exc):
        """Return True if the text of *exc* contains a connection-error marker."""
        message = str(exc).lower()
        return any(marker in message for marker in self._CONNECTION_ERROR_MARKERS)

    def _attempts(self):
        """Return the attempt indices to iterate over.

        Always yields at least one attempt, so a negative ``max_retries``
        cannot turn ``execute_sql``/``begin`` into silent no-ops.
        """
        return range(max(self.max_retries, 0) + 1)

    def execute_sql(self, sql, params=None, commit=True):
        """Run *sql* through the parent, retrying on connection errors.

        Args:
            sql: SQL statement passed through to the parent ``execute_sql``.
            params: Optional statement parameters, passed through unchanged.
            commit: Passed through unchanged to the parent ``execute_sql``.

        Returns:
            Whatever the parent ``execute_sql`` returns.

        Raises:
            OperationalError, InterfaceError: Re-raised when the error is not
                recognised as a connection problem or the retries are
                exhausted. An exception raised while reconnecting also
                propagates.
        """
        for attempt in self._attempts():
            try:
                return super().execute_sql(sql, params, commit)
            except (OperationalError, InterfaceError) as e:
                should_retry = self._is_connection_error(e)

                if should_retry and attempt < self.max_retries:
                    logging.warning(
                        f"PostgreSQL connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
                    )
                    self._handle_connection_loss()
                    # Exponential backoff: delay doubles on every attempt.
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    logging.error(f"PostgreSQL execution failure: {e}")
                    raise

    def _handle_connection_loss(self):
        """Close the current connection and open a new one.

        Errors from ``close()`` are ignored on purpose: the connection is
        already considered broken. If the first ``connect()`` fails, the
        failure is logged and one more ``connect()`` is attempted after a
        short pause; an exception from that second attempt propagates.
        """
        try:
            self.close()
        except Exception:
            # Best effort only; the reconnect below is what matters.
            pass
        try:
            self.connect()
        except Exception as e:
            logging.error(f"Failed to reconnect to PostgreSQL: {e}")
            time.sleep(0.1)
            self.connect()

    def begin(self):
        """Call the parent ``begin()``, retrying on connection errors.

        Returns:
            Whatever the parent ``begin()`` returns.

        Raises:
            OperationalError, InterfaceError: Re-raised when the error is not
                recognised as a connection problem or the retries are
                exhausted. An exception raised while reconnecting also
                propagates.
        """
        for attempt in self._attempts():
            try:
                return super().begin()
            except (OperationalError, InterfaceError) as e:
                should_retry = self._is_connection_error(e)

                if should_retry and attempt < self.max_retries:
                    logging.warning(
                        f"PostgreSQL connection lost during transaction (attempt {attempt+1}/{self.max_retries})"
                    )
                    self._handle_connection_loss()
                    # Exponential backoff: delay doubles on every attempt.
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    raise
|
||||||
|
|
||||||
|
|
||||||
class PooledDatabase(Enum):
|
class PooledDatabase(Enum):
|
||||||
MYSQL = RetryingPooledMySQLDatabase
|
MYSQL = RetryingPooledMySQLDatabase
|
||||||
POSTGRES = PooledPostgresqlDatabase
|
POSTGRES = RetryingPooledPostgresqlDatabase
|
||||||
|
|
||||||
|
|
||||||
class DatabaseMigrator(Enum):
|
class DatabaseMigrator(Enum):
|
||||||
|
|||||||
Reference in New Issue
Block a user