mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fixed the issue where database connections were interrupted under high concurrency (#10126)
### What problem does this PR solve? Fixed the issue where database connections were interrupted under high concurrency ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: lemsn <lemsn@126.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -26,7 +26,7 @@ from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
@ -250,36 +250,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().execute_sql(sql, params, commit)
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
logging.error(f"DB execution failure: {e}")
|
||||
raise
|
||||
return None
|
||||
|
||||
def _handle_connection_loss(self):
|
||||
self.close_all()
|
||||
self.connect()
|
||||
# self.close_all()
|
||||
# self.connect()
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
|
||||
def begin(self):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().begin()
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -299,7 +326,16 @@ class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = settings.DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
|
||||
pool_config = {
|
||||
'max_retries': 5,
|
||||
'retry_delay': 1,
|
||||
}
|
||||
database_config.update(pool_config)
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||
db_name, **database_config
|
||||
)
|
||||
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
logging.info("init database on cluster mode successfully")
|
||||
|
||||
|
||||
|
||||
@ -14,12 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import peewee
|
||||
from peewee import InterfaceError, OperationalError
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
def retry_db_operation(func):
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||
reraise=True,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class CommonService:
|
||||
"""Base service class that provides common database operations.
|
||||
@ -202,6 +214,7 @@ class CommonService:
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@retry_db_operation
|
||||
def update_by_id(cls, pid, data):
|
||||
# Update a single record by ID
|
||||
# Args:
|
||||
|
||||
Reference in New Issue
Block a user