@@ -144,9 +144,9 @@ class Base(ABC):
         if self.model_name.lower().find("qwen3") >= 0:
             kwargs["extra_body"] = {"enable_thinking": False}
-        response = self.client.responses.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
+        response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
-        if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+        if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
             return "", 0
         ans = response.choices[0].message.content.strip()
         if response.choices[0].finish_reason == "length":
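Note on the change above: the OpenAI SDK's Responses API does not accept a `messages` keyword (its input parameter is `input`), so the previous `responses.create(model=..., messages=...)` calls could not have worked against the real client; switching back to `chat.completions.create` matches the message-list payload the code builds. A minimal sketch of the two call shapes, assuming a standard `openai` client and a hypothetical model name:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
history = [{"role": "user", "content": "ping"}]

# Chat Completions API: takes a list of role/content messages.
chat = client.chat.completions.create(model="gpt-4o-mini", messages=history)
print(chat.choices[0].message.content)

# Responses API: takes `input`, not `messages`, and exposes `output_text`.
resp = client.responses.create(model="gpt-4o-mini", input="ping")
print(resp.output_text)
```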
@@ -158,9 +158,9 @@ class Base(ABC):
         reasoning_start = False
         if kwargs.get("stop") or "stop" in gen_conf:
-            response = self.client.responses.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
         else:
-            response = self.client.responses.create(model=self.model_name, messages=history, stream=True, **gen_conf)
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
         for resp in response:
             if not resp.choices:
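Not part of this diff, but the two streaming branches above differ only in the `stop` argument, so they could be collapsed into one call. A sketch using the surrounding method's names (`_create_stream` is a hypothetical helper):

```python
def _create_stream(self, history, gen_conf, **kwargs):
    # Hypothetical helper: fold the optional stop sequence into one config
    # dict so a single chat.completions.create call covers both branches.
    conf = dict(gen_conf)
    stop = kwargs.get("stop") or conf.pop("stop", None)
    if stop is not None:
        conf["stop"] = stop
    return self.client.chat.completions.create(
        model=self.model_name, messages=history, stream=True, **conf
    )
```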
@@ -193,21 +193,30 @@ class Base(ABC):
                 return ans + LENGTH_NOTIFICATION_CN
             return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
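The hunk above separates policy (which error codes are retryable) from mechanism (backoff, logging, the final give-up), and gives `_exceptions` an explicit contract: `None` means "already slept, retry", a string means "final error text". A sketch of the consuming loop this contract implies; the real wrapper lives elsewhere in the file and the names here are illustrative:

```python
def chat_with_retries(self, history, gen_conf):
    # Illustrative consumer: _exceptions returns None to request another
    # attempt (it has already slept), or a final error string to give up.
    for attempt in range(self.max_retries + 1):
        try:
            return self._chat(history, gen_conf)
        except Exception as e:
            err = self._exceptions(e, attempt)
            if err is not None:
                return err, 0
```

On the last attempt `_exceptions` reclassifies the error as `ERROR_MAX_RETRIES`, which is not in `_retryable_errors`, so the loop always terminates with a return.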
@@ -257,7 +266,7 @@ class Base(ABC):
         try:
             for _ in range(self.max_rounds + 1):
                 logging.info(f"{self.tools=}")
-                response = self.client.responses.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
+                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
                 tk_count += self.total_token_count(response)
                 if any([not response.choices, not response.choices[0].message]):
                     raise Exception(f"500 response structure error. Response: {response}")
@@ -342,7 +351,7 @@ class Base(ABC):
             for _ in range(self.max_rounds + 1):
                 reasoning_start = False
                 logging.info(f"{tools=}")
-                response = self.client.responses.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+                response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
                 final_tool_calls = {}
                 answer = ""
                 for resp in response:
@@ -405,7 +414,7 @@ class Base(ABC):
 
             logging.warning(f"Exceed max rounds: {self.max_rounds}")
             history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
-            response = self.client.responses.create(model=self.model_name, messages=history, stream=True, **gen_conf)
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
             for resp in response:
                 if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
                     raise Exception("500 response structure error.")
@@ -536,6 +545,14 @@ class AzureChat(Base):
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
         self.model_name = model_name
 
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+            LLMErrorCode.ERROR_QUOTA,
+        }
+
 
 class BaiChuanChat(Base):
     _FACTORY_NAME = "BaiChuan"
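Because `_retryable_errors` is a property, this override widens the retry policy declaratively: `Base._should_retry` picks up the Azure set through ordinary attribute lookup. A self-contained demonstration of that dispatch, with placeholder strings standing in for the real constants:

```python
class LLMErrorCode:  # placeholder constants, not the real module's values
    ERROR_RATE_LIMIT = "RATE_LIMIT"
    ERROR_SERVER = "SERVER"
    ERROR_QUOTA = "QUOTA"

class Base:
    @property
    def _retryable_errors(self) -> set[str]:
        return {LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER}

    def _should_retry(self, error_code: str) -> bool:
        return error_code in self._retryable_errors

class AzureChat(Base):
    @property
    def _retryable_errors(self) -> set[str]:
        # Azure additionally retries quota errors, as in the hunk above.
        return {LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER, LLMErrorCode.ERROR_QUOTA}

assert not Base()._should_retry(LLMErrorCode.ERROR_QUOTA)
assert AzureChat()._should_retry(LLMErrorCode.ERROR_QUOTA)
```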
@@ -559,7 +576,7 @@ class BaiChuanChat(Base):
         }
 
     def _chat(self, history, gen_conf={}, **kwargs):
-        response = self.client.responses.create(
+        response = self.client.chat.completions.create(
             model=self.model_name,
             messages=history,
             extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
@@ -581,7 +598,7 @@ class BaiChuanChat(Base):
         ans = ""
         total_tokens = 0
         try:
-            response = self.client.responses.create(
+            response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
                 extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
@@ -651,7 +668,7 @@ class ZhipuChat(Base):
         tk_count = 0
         try:
             logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-            response = self.client.responses.create(model=self.model_name, messages=history, stream=True, **gen_conf)
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
             for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
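One thing this hunk leaves as-is: `resp.choices[0].delta.content` is read without first checking that `resp.choices` is non-empty, unlike the guarded loops in `Base`. If empty-choice chunks can occur on Zhipu's stream (an assumption), a defensive variant would be:

```python
for resp in response:
    # Guard against chunks with an empty choices list before indexing,
    # matching the structure checks in Base's streaming loop.
    if not resp.choices or not resp.choices[0].delta.content:
        continue
    delta = resp.choices[0].delta.content
```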
@@ -1364,7 +1381,7 @@ class LiteLLMBase(ABC):
             drop_params=True,
             timeout=self.timeout,
         )
-        # response = self.client.responses.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
+        # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
 
         if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
             return "", 0
@@ -1424,21 +1441,30 @@ class LiteLLMBase(ABC):
                 return ans + LENGTH_NOTIFICATION_CN
             return ans + LENGTH_NOTIFICATION_EN
 
-    def _exceptions(self, e, attempt):
+    @property
+    def _retryable_errors(self) -> set[str]:
+        return {
+            LLMErrorCode.ERROR_RATE_LIMIT,
+            LLMErrorCode.ERROR_SERVER,
+        }
+
+    def _should_retry(self, error_code: str) -> bool:
+        return error_code in self._retryable_errors
+
+    def _exceptions(self, e, attempt) -> str | None:
         logging.exception("OpenAI chat_with_tools")
         # Classify the error
         error_code = self._classify_error(e)
         if attempt == self.max_retries:
             error_code = LLMErrorCode.ERROR_MAX_RETRIES
 
-        # Check if it's a rate limit error or server error and not the last attempt
-        should_retry = error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER
-        if not should_retry:
-            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+            return None
 
-        delay = self._get_delay()
-        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-        time.sleep(delay)
+        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
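The `_retryable_errors` / `_should_retry` / `_exceptions` trio added here is identical to the `Base` version earlier in the diff. If the two hierarchies can share code, a mixin would remove the duplication; a sketch, assuming nothing else in either MRO defines these names:

```python
from abc import ABC

class LLMErrorCode:  # placeholder for the module's real constants
    ERROR_RATE_LIMIT = "RATE_LIMIT"
    ERROR_SERVER = "SERVER"

class RetryPolicyMixin:
    """Retry policy shared by Base and LiteLLMBase, which currently carry identical copies."""

    @property
    def _retryable_errors(self) -> set[str]:
        return {LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER}

    def _should_retry(self, error_code: str) -> bool:
        return error_code in self._retryable_errors

class Base(RetryPolicyMixin, ABC):
    ...

class LiteLLMBase(RetryPolicyMixin, ABC):
    ...
```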