diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 5d92d01e7..428d8542d 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -193,21 +193,30 @@ class Base(ABC):
             return ans + LENGTH_NOTIFICATION_CN
         return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""
@@ -536,6 +545,14 @@ class AzureChat(Base):
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
         self.model_name = model_name
 
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+            LLMErrorCode.ERROR_QUOTA,
+        }
+
 
 class BaiChuanChat(Base):
     _FACTORY_NAME = "BaiChuan"
@@ -1424,21 +1441,30 @@ class LiteLLMBase(ABC):
             return ans + LENGTH_NOTIFICATION_CN
         return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""