diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 5d92d01e7..428d8542d 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -193,21 +193,30 @@ class Base(ABC):
return ans + LENGTH_NOTIFICATION_CN
return ans + LENGTH_NOTIFICATION_EN
- def _exceptions(self, e, attempt):
+ @property
+ def _retryable_errors(self) -> set[str]:
+ return {
+ LLMErrorCode.ERROR_RATE_LIMIT,
+ LLMErrorCode.ERROR_SERVER,
+ }
+
+ def _should_retry(self, error_code: str) -> bool:
+ return error_code in self._retryable_errors
+
+ def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
- # Check if it's a rate limit error or server error and not the last attempt
- should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
- if not should_retry:
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+            logging.warning("Error: %s. Retrying in %.2f seconds... (Attempt %d/%d)", error_code, delay, attempt + 1, self.max_retries)
+ time.sleep(delay)
+ return None
- delay = self._get_delay()
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
+ return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""
@@ -536,6 +545,14 @@ class AzureChat(Base):
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
self.model_name = model_name
+ @property
+ def _retryable_errors(self) -> set[str]:
+ return {
+ LLMErrorCode.ERROR_RATE_LIMIT,
+ LLMErrorCode.ERROR_SERVER,
+            LLMErrorCode.ERROR_QUOTA,  # NOTE(review): retries quota errors, unlike Base — valid for transient Azure 429 throttling, but loops pointlessly on hard quota exhaustion; confirm intended
+ }
+
class BaiChuanChat(Base):
_FACTORY_NAME = "BaiChuan"
@@ -1424,21 +1441,30 @@ class LiteLLMBase(ABC):
return ans + LENGTH_NOTIFICATION_CN
return ans + LENGTH_NOTIFICATION_EN
- def _exceptions(self, e, attempt):
+ @property
+ def _retryable_errors(self) -> set[str]:
+ return {
+ LLMErrorCode.ERROR_RATE_LIMIT,
+ LLMErrorCode.ERROR_SERVER,
+ }
+
+ def _should_retry(self, error_code: str) -> bool:
+ return error_code in self._retryable_errors
+
+ def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
- # Check if it's a rate limit error or server error and not the last attempt
- should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
- if not should_retry:
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+            logging.warning("Error: %s. Retrying in %.2f seconds... (Attempt %d/%d)", error_code, delay, attempt + 1, self.max_retries)
+ time.sleep(delay)
+ return None
- delay = self._get_delay()
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
+ return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""