mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fixed the issue where database connections were interrupted under high concurrency (#10126)
### What problem does this PR solve? Fixed the issue where database connections were interrupted under high concurrency ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: lemsn <lemsn@126.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -26,7 +26,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||||
|
|
||||||
@ -250,36 +250,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def execute_sql(self, sql, params=None, commit=True):
|
def execute_sql(self, sql, params=None, commit=True):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().execute_sql(sql, params, commit)
|
return super().execute_sql(sql, params, commit)
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
error_messages = ['', 'Lost connection']
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
logging.error(f"DB execution failure: {e}")
|
logging.error(f"DB execution failure: {e}")
|
||||||
raise
|
raise
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_connection_loss(self):
|
def _handle_connection_loss(self):
|
||||||
self.close_all()
|
# self.close_all()
|
||||||
self.connect()
|
# self.connect()
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.connect()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to reconnect: {e}")
|
||||||
|
time.sleep(0.1)
|
||||||
|
self.connect()
|
||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().begin()
|
return super().begin()
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
error_messages = ['', 'Lost connection']
|
||||||
|
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -299,7 +326,16 @@ class BaseDataBase:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
database_config = settings.DATABASE.copy()
|
database_config = settings.DATABASE.copy()
|
||||||
db_name = database_config.pop("name")
|
db_name = database_config.pop("name")
|
||||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
|
||||||
|
pool_config = {
|
||||||
|
'max_retries': 5,
|
||||||
|
'retry_delay': 1,
|
||||||
|
}
|
||||||
|
database_config.update(pool_config)
|
||||||
|
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||||
|
db_name, **database_config
|
||||||
|
)
|
||||||
|
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||||
logging.info("init database on cluster mode successfully")
|
logging.info("init database on cluster mode successfully")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,12 +14,24 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||||
import peewee
|
import peewee
|
||||||
|
from peewee import InterfaceError, OperationalError
|
||||||
|
|
||||||
from api.db.db_models import DB
|
from api.db.db_models import DB
|
||||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||||
|
|
||||||
|
def retry_db_operation(func):
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||||
|
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||||
|
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
class CommonService:
|
class CommonService:
|
||||||
"""Base service class that provides common database operations.
|
"""Base service class that provides common database operations.
|
||||||
@ -202,6 +214,7 @@ class CommonService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
|
@retry_db_operation
|
||||||
def update_by_id(cls, pid, data):
|
def update_by_id(cls, pid, data):
|
||||||
# Update a single record by ID
|
# Update a single record by ID
|
||||||
# Args:
|
# Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user