Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Fix: Add RetryingPooledPostgresqlDatabase to handle max_retries param (#10524)
## What problem does this PR solve?

Fixes the PostgreSQL connection error that prevents RAGFlow from starting:

```
peewee.ProgrammingError: invalid dsn: invalid connection option "max_retries"
```

## Problem Analysis

The `BaseDataBase` class in `api/db/db_models.py` adds `max_retries` and `retry_delay` to the database configuration dict before passing it to the database connection constructor.

- **MySQL**: has the `RetryingPooledMySQLDatabase` class, which properly extracts these custom parameters with `kwargs.pop()` before calling the parent constructor.
- **PostgreSQL**: was using the base `PooledPostgresqlDatabase` class, which passes all parameters directly to `psycopg2.connect()`, and psycopg2 does not recognize `max_retries` as a valid connection option.

## Solution

Created a `RetryingPooledPostgresqlDatabase` class that:

- Extracts `max_retries` and `retry_delay` before initialization
- Implements retry logic with exponential backoff for connection failures
- Handles PostgreSQL-specific connection errors (connection refused, server closed, etc.)
- Mirrors the existing `RetryingPooledMySQLDatabase` implementation

Also updated the `PooledDatabase` enum to use the new retrying class for PostgreSQL.

## Benefits

✅ Prevents invalid connection parameters from being passed to psycopg2
✅ Adds automatic retry logic for PostgreSQL connection failures
✅ Provides better error logging for PostgreSQL-specific issues
✅ Maintains consistency between MySQL and PostgreSQL database handling

## Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

## Testing

Tested with a PostgreSQL database configuration and verified that:

- the server starts without the "invalid dsn" error
- database connections are established successfully
- the retry logic works correctly on connection failures

Co-authored-by: Andrea Bugeja <andrea.bugeja@gig.com>
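For extra context, below is a minimal, illustrative sketch of the failure mode and of the `kwargs.pop()` pattern the new class applies. The config keys and values are assumptions for illustration only, not RAGFlow's actual settings.

```python
from playhouse.pool import PooledPostgresqlDatabase

# Hypothetical config in the spirit of what BaseDataBase assembles;
# the two custom retry keys are the ones psycopg2 rejects.
db_config = {
    "host": "localhost",
    "port": 5432,
    "user": "ragflow",
    "password": "secret",
    "max_connections": 32,   # pool option understood by playhouse
    "stale_timeout": 300,    # pool option understood by playhouse
    "max_retries": 5,        # custom key -> "invalid connection option" from psycopg2
    "retry_delay": 1,        # custom key -> same problem
}

# Pre-fix behaviour: the unknown keys are forwarded to psycopg2.connect(),
# so opening the first connection fails with the "invalid dsn" error:
# db = PooledPostgresqlDatabase("ragflow", **db_config)
# db.connect()  # peewee.ProgrammingError: invalid dsn: ...

# The fix pattern (what the new subclass does internally): pop the custom
# keys before the parent constructor ever sees them.
max_retries = db_config.pop("max_retries", 5)
retry_delay = db_config.pop("retry_delay", 1)
db = PooledPostgresqlDatabase("ragflow", **db_config)  # only psycopg2-valid options remain
```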
```diff
@@ -313,9 +313,75 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
                     raise
 
 
+class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
+    def __init__(self, *args, **kwargs):
+        self.max_retries = kwargs.pop("max_retries", 5)
+        self.retry_delay = kwargs.pop("retry_delay", 1)
+        super().__init__(*args, **kwargs)
+
+    def execute_sql(self, sql, params=None, commit=True):
+        for attempt in range(self.max_retries + 1):
+            try:
+                return super().execute_sql(sql, params, commit)
+            except (OperationalError, InterfaceError) as e:
+                # PostgreSQL specific error codes
+                # 57P01: admin_shutdown
+                # 57P02: crash_shutdown
+                # 57P03: cannot_connect_now
+                # 08006: connection_failure
+                # 08003: connection_does_not_exist
+                # 08000: connection_exception
+                error_messages = ['connection', 'server closed', 'connection refused',
+                                  'no connection to the server', 'terminating connection']
+
+                should_retry = any(msg in str(e).lower() for msg in error_messages)
+
+                if should_retry and attempt < self.max_retries:
+                    logging.warning(
+                        f"PostgreSQL connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
+                    )
+                    self._handle_connection_loss()
+                    time.sleep(self.retry_delay * (2 ** attempt))
+                else:
+                    logging.error(f"PostgreSQL execution failure: {e}")
+                    raise
+        return None
+
+    def _handle_connection_loss(self):
+        try:
+            self.close()
+        except Exception:
+            pass
+        try:
+            self.connect()
+        except Exception as e:
+            logging.error(f"Failed to reconnect to PostgreSQL: {e}")
+            time.sleep(0.1)
+            self.connect()
+
+    def begin(self):
+        for attempt in range(self.max_retries + 1):
+            try:
+                return super().begin()
+            except (OperationalError, InterfaceError) as e:
+                error_messages = ['connection', 'server closed', 'connection refused',
+                                  'no connection to the server', 'terminating connection']
+
+                should_retry = any(msg in str(e).lower() for msg in error_messages)
+
+                if should_retry and attempt < self.max_retries:
+                    logging.warning(
+                        f"PostgreSQL connection lost during transaction (attempt {attempt+1}/{self.max_retries})"
+                    )
+                    self._handle_connection_loss()
+                    time.sleep(self.retry_delay * (2 ** attempt))
+                else:
+                    raise
+
+
 class PooledDatabase(Enum):
     MYSQL = RetryingPooledMySQLDatabase
-    POSTGRES = PooledPostgresqlDatabase
+    POSTGRES = RetryingPooledPostgresqlDatabase
 
 
 class DatabaseMigrator(Enum):
```
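For reference, a hedged sketch of how the updated enum might be consumed when PostgreSQL is configured. The settings dict, key names, and values below are assumptions; the actual wiring in `api/db/db_models.py` may differ.

```python
from api.db.db_models import PooledDatabase  # module named in the PR description

# Assumed settings; in RAGFlow these would come from the service configuration.
database_config = {
    "host": "postgres",
    "port": 5432,
    "user": "rag_flow",
    "password": "secret",
    "max_retries": 5,   # popped by RetryingPooledPostgresqlDatabase.__init__
    "retry_delay": 1,   # popped by RetryingPooledPostgresqlDatabase.__init__
}

DATABASE_TYPE = "postgres"  # assumed config value
DB = PooledDatabase[DATABASE_TYPE.upper()].value("rag_flow", **database_config)

# Because __init__ pops max_retries/retry_delay, only psycopg2-recognized
# options reach the pooled connection. On connection errors, execute_sql
# backs off retry_delay * 2**attempt seconds: roughly 1s, 2s, 4s, 8s, 16s
# before the final attempt re-raises.
```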