diff --git a/api/db/db_models.py b/api/db/db_models.py index b2681ac3e..3342e1f18 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -243,8 +243,51 @@ class JsonSerializedField(SerializedField): super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, object_pairs_hook=object_pairs_hook, **kwargs) +class RetryingPooledMySQLDatabase(PooledMySQLDatabase): + def __init__(self, *args, **kwargs): + self.max_retries = kwargs.pop('max_retries', 5) + self.retry_delay = kwargs.pop('retry_delay', 1) + super().__init__(*args, **kwargs) + + def execute_sql(self, sql, params=None, commit=True): + from peewee import OperationalError + for attempt in range(self.max_retries + 1): + try: + return super().execute_sql(sql, params, commit) + except OperationalError as e: + if e.args[0] in (2013, 2006) and attempt < self.max_retries: + logging.warning( + f"Lost connection (attempt {attempt+1}/{self.max_retries}): {e}" + ) + self._handle_connection_loss() + time.sleep(self.retry_delay * (2 ** attempt)) + else: + logging.error(f"DB execution failure: {e}") + raise + return None + + def _handle_connection_loss(self): + self.close_all() + self.connect() + + def begin(self): + from peewee import OperationalError + for attempt in range(self.max_retries + 1): + try: + return super().begin() + except OperationalError as e: + if e.args[0] in (2013, 2006) and attempt < self.max_retries: + logging.warning( + f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})" + ) + self._handle_connection_loss() + time.sleep(self.retry_delay * (2 ** attempt)) + else: + raise + + class PooledDatabase(Enum): - MYSQL = PooledMySQLDatabase + MYSQL = RetryingPooledMySQLDatabase POSTGRES = PooledPostgresqlDatabase diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 23134425f..da68c811a 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -47,7 +47,7 @@ class Extractor: self._language = language self._entity_types = entity_types or DEFAULT_ENTITY_TYPES - @timeout(60) + @timeout(60*3) def _chat(self, system, history, gen_conf): hist = deepcopy(history) conf = deepcopy(gen_conf)