Refa: chat with tools. (#8210)

### What problem does this PR solve?


### Type of change
- [x] Refactoring
This commit is contained in:
Kevin Hu
2025-06-12 12:31:10 +08:00
committed by GitHub
parent 44287fb05f
commit 56ee69e9d9
2 changed files with 130 additions and 171 deletions

View File

@ -26,6 +26,7 @@ from http import HTTPStatus
from typing import Any, Protocol from typing import Any, Protocol
from urllib.parse import urljoin from urllib.parse import urljoin
import json_repair
import openai import openai
import requests import requests
from dashscope import Generation from dashscope import Generation
@ -67,11 +68,12 @@ class Base(ABC):
# Configure retry parameters # Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False self.is_tools = False
def _get_delay(self, attempt): def _get_delay(self):
"""Calculate retry delay time""" """Calculate retry delay time"""
return self.base_delay * (2**attempt) + random.uniform(0, 0.5) return self.base_delay + random.uniform(0, 0.5)
def _classify_error(self, error): def _classify_error(self, error):
"""Classify error based on error message content""" """Classify error based on error message content"""
@ -116,6 +118,29 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response) return ans, self.total_token_count(response)
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
return ans + LENGTH_NOTIFICATION_EN
def _exceptions(self, e, attempt):
logging.exception("OpenAI cat_with_tools")
# Classify the error
error_code = self._classify_error(e)
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
if should_retry:
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
if attempt == self.max_retries:
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def bind_tools(self, toolcall_session, tools): def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools): if not (toolcall_session and tools):
return return
@ -124,76 +149,48 @@ class Base(ABC):
self.tools = tools self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict): def chat_with_tools(self, system: str, history: list, gen_conf: dict):
if "max_tokens" in gen_conf: gen_conf = self._clean_conf()
del gen_conf["max_tokens"]
tools = self.tools
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
ans = "" ans = ""
tk_count = 0 tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy # Implement exponential backoff retry strategy
for attempt in range(self.max_retries): for attempt in range(self.max_retries+1):
try: history = hist
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) for _ in range(self.max_rounds*2):
try:
assistant_output = response.choices[0].message response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
ans += "<think>" + ans + "</think>"
ans += response.choices[0].message.content
if not response.choices[0].message.tool_calls:
tk_count += self.total_token_count(response) tk_count += self.total_token_count(response)
if response.choices[0].finish_reason == "length": if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
if is_chinese([ans]): raise Exception("500 response structure error.")
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
tk_count += self.total_token_count(response) if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
history.append(assistant_output) if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
for tool_call in response.choices[0].message.tool_calls: ans += response.choices[0].message.content
name = tool_call.function.name if response.choices[0].finish_reason == "length":
args = json.loads(tool_call.function.arguments) ans = self._length_stop(ans)
tool_response = self.toolcall_session.tool_call(name, args) return ans, tk_count
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) for tool_call in response.choices[0].message.tool_calls:
assistant_output = final_response.choices[0].message name = tool_call.function.name
if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: try:
ans += "<think>" + ans + "</think>" args = json_repair.loads(tool_call.function.arguments)
ans += final_response.choices[0].message.content tool_response = self.toolcall_session.tool_call(name, args)
if final_response.choices[0].finish_reason == "length": history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
tk_count += self.total_token_count(response) except Exception as e:
if is_chinese([ans]): history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
return ans, tk_count
except Exception as e:
logging.exception("OpenAI cat_with_tools")
# Classify the error
error_code = self._classify_error(e)
# Check if it's a rate limit error or server error and not the last attempt except Exception as e:
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 e = self._exceptions(e, attempt)
if e:
if should_retry: return e, tk_count
delay = self._get_delay(attempt) assert False, "Shouldn't be here."
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
if attempt == self.max_retries - 1:
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: if system:
@ -201,26 +198,14 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy # Implement exponential backoff retry strategy
for attempt in range(self.max_retries): for attempt in range(self.max_retries+1):
try: try:
return self._chat(history, gen_conf) return self._chat(history, gen_conf)
except Exception as e: except Exception as e:
logging.exception("chat_model.Base.chat got exception") e = self._exceptions(e, attempt)
# Classify the error if e:
error_code = self._classify_error(e) return e, 0
assert False, "Shouldn't be here."
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
if should_retry:
delay = self._get_delay(attempt)
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
if attempt == self.max_retries - 1:
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
def _wrap_toolcall_message(self, stream): def _wrap_toolcall_message(self, stream):
final_tool_calls = {} final_tool_calls = {}
@ -241,41 +226,48 @@ class Base(ABC):
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
tools = self.tools tools = self.tools
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
ans = ""
total_tokens = 0 total_tokens = 0
reasoning_start = False hist = deepcopy(history)
finish_completion = False # Implement exponential backoff retry strategy
final_tool_calls = {} for attempt in range(self.max_retries+1):
try: history = hist
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) for _ in range(self.max_rounds*2):
while not finish_completion: reasoning_start = False
for resp in response: try:
if resp.choices[0].delta.tool_calls: response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
for tool_call in resp.choices[0].delta.tool_calls or []: final_tool_calls = {}
index = tool_call.index answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls: if index not in final_tool_calls:
final_tool_calls[index] = tool_call final_tool_calls[index] = tool_call
else: else:
final_tool_calls[index].function.arguments += tool_call.function.arguments final_tool_calls[index].function.arguments += tool_call.function.arguments
else:
if not resp.choices:
continue continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = "" ans = ""
if not reasoning_start: if not reasoning_start:
reasoning_start = True reasoning_start = True
ans = "<think>" ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>" ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else: else:
reasoning_start = False reasoning_start = False
ans = resp.choices[0].delta.content answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = self.total_token_count(resp) tol = self.total_token_count(resp)
if not tol: if not tol:
@ -283,18 +275,18 @@ class Base(ABC):
else: else:
total_tokens += tol total_tokens += tol
finish_reason = resp.choices[0].finish_reason finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "tool_calls" and final_tool_calls: if finish_reason == "length":
for tool_call in final_tool_calls.values(): yield self._length_stop("")
name = tool_call.function.name
try:
args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
yield ans + "\n**ERROR**: " + str(e)
finish_completion = True
break
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args) tool_response = self.toolcall_session.tool_call(name, args)
history.append( history.append(
{ {
@ -313,26 +305,16 @@ class Base(ABC):
} }
) )
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_tool_calls = {} except Exception as e:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
continue history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
if finish_reason == "length": except Exception as e:
if is_chinese(ans): e = self._exceptions(e, attempt)
ans += LENGTH_NOTIFICATION_CN if e:
else: yield total_tokens
ans += LENGTH_NOTIFICATION_EN return
return ans, total_tokens
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
continue
except openai.APIError as e: assert False, "Shouldn't be here."
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def chat_streamly(self, system, history, gen_conf): def chat_streamly(self, system, history, gen_conf):
if system: if system:
@ -636,49 +618,21 @@ class QWenChat(Base):
return "".join(result_list[:-1]), result_list[-1] return "".join(result_list[:-1]), result_list[-1]
def _chat(self, history, gen_conf): def _chat(self, history, gen_conf):
tk_count = 0
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
try: return super()._chat(history, gen_conf)
response = super()._chat(history, gen_conf) response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
return response ans = ""
except Exception as e: tk_count = 0
error_msg = str(e).lower() if response.status_code == HTTPStatus.OK:
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg: ans += response.output.choices[0]["message"]["content"]
return self._simulate_one_shot_from_stream(history, gen_conf) tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else: else:
return "**ERROR**: " + str(e), tk_count ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
try: return "**ERROR**: " + response.message, tk_count
ans = ""
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
return "**ERROR**: " + response.message, tk_count
except Exception as e:
error_msg = str(e).lower()
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
return self._simulate_one_shot_from_stream(history, gen_conf)
else:
return "**ERROR**: " + str(e), tk_count
def _simulate_one_shot_from_stream(self, history, gen_conf):
"""
Handles models that require streaming output but need one-shot response.
"""
g = self._chat_streamly("", history, gen_conf, incremental_output=True)
result_list = list(g)
error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
if len(error_msg_list) > 0:
return "**ERROR**: " + "".join(error_msg_list), 0
else:
return "".join(result_list[:-1]), result_list[-1]
def _wrap_toolcall_message(self, old_message, message): def _wrap_toolcall_message(self, old_message, message):
if not old_message: if not old_message:
@ -971,10 +925,10 @@ class LocalAIChat(Base):
class LocalLLM(Base): class LocalLLM(Base):
def __init__(self, key, model_name, base_url=None, **kwargs): def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs) super().__init__(key, model_name, base_url=base_url, **kwargs)
from jina import Client from jina import Client
self.client = Client(port=12345, protocol="grpc", asyncio=True) self.client = Client(port=12345, protocol="grpc", asyncio=True)
def _prepare_prompt(self, system, history, gen_conf): def _prepare_prompt(self, system, history, gen_conf):
@ -1031,7 +985,13 @@ class VolcEngineChat(Base):
class MiniMaxChat(Base): class MiniMaxChat(Base):
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): def __init__(
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
**kwargs
):
super().__init__(key, model_name, base_url=base_url, **kwargs) super().__init__(key, model_name, base_url=base_url, **kwargs)
if not base_url: if not base_url:
@ -1263,7 +1223,6 @@ class GeminiChat(Base):
def _chat(self, history, gen_conf): def _chat(self, history, gen_conf):
from google.generativeai.types import content_types from google.generativeai.types import content_types
system = history[0]["content"] if history and history[0]["role"] == "system" else "" system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = [] hist = []
for item in history: for item in history:
@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
if not base_url: if not base_url:
raise ValueError("Local llm url cannot be None") raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1") base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs) super().__init__(key, model_name, base_url, **kwargs)

View File

@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []}) doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]): for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "") cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
cnt += ck["content_with_weight"] cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt) doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {}) doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})