Refa: limit embedding concurrency and fix chat_with_tool (#8543)

### What problem does this PR solve?

#8538

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
Kevin Hu
2025-06-27 19:28:41 +08:00
committed by GitHub
parent 8e1f8a0c48
commit e441c17c2c
2 changed files with 75 additions and 303 deletions

View File

@ -18,11 +18,9 @@ import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
from http import HTTPStatus
from typing import Any, Protocol
from urllib.parse import urljoin
@ -61,9 +59,6 @@ class ToolCallSession(Protocol):
class Base(ABC):
tools: list[Any]
toolcall_sessions: dict[str, ToolCallSession]
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@ -146,6 +141,37 @@ class Base(ABC):
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({
"name": name,
"args": args,
"result": res
}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
try:
if isinstance(tool_res, dict):
tool_res = json.dumps(tool_res, ensure_ascii=False)
finally:
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
return hist
def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
@ -160,18 +186,19 @@ class Base(ABC):
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
for attempt in range(self.max_retries+1):
history = hist
for _ in range(self.max_rounds * 2):
try:
try:
for _ in range(self.max_rounds*2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
raise Exception("500 response structure error.")
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@ -188,14 +215,17 @@ class Base(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf):
@ -228,9 +258,7 @@ class Base(ABC):
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system:
history.insert(0, {"role": "system", "content": system})
@ -240,9 +268,9 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
for _ in range(self.max_rounds * 2):
reasoning_start = False
try:
try:
for _ in range(self.max_rounds*2):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
final_tool_calls = {}
answer = ""
@ -252,9 +280,11 @@ class Base(ABC):
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
@ -293,40 +323,26 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
tool_response = self.toolcall_session[name].tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield total_tokens
return
yield self._verbose_tool_use(name, {}, str(e))
assert False, "Shouldn't be here."
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield total_tokens
return
yield total_tokens
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
reasoning_start = False
@ -542,252 +558,8 @@ class BaiChuanChat(Base):
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
import dashscope
dashscope.api_key = key
self.model_name = model_name
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
# if self.is_reasoning_model(self.model_name):
# return super().chat(system, history, gen_conf)
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
if not stream_flag:
from http import HTTPStatus
tools = self.tools
if system:
history.insert(0, {"role": "system", "content": system})
response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
assistant_output = response.output.choices[0].message
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
ans += "<think>" + ans + "</think>"
ans += response.output.choices[0].message.content
if "tool_calls" not in assistant_output:
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
tk_count += self.total_token_count(response)
history.append(assistant_output)
while "tool_calls" in assistant_output:
tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
tool_name = assistant_output.tool_calls[0]["function"]["name"]
if tool_name:
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
history.append(tool_info)
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
if response.output.choices[0].get("finish_reason", "") == "length":
tk_count += self.total_token_count(response)
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
tk_count += self.total_token_count(response)
assistant_output = response.output.choices[0].message
if assistant_output.content is None:
assistant_output.content = ""
history.append(response)
ans += assistant_output["content"]
return ans, tk_count
else:
return "**ERROR**: " + response.message, tk_count
else:
result_list = []
for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
result_list.append(result)
error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
if len(error_msg_list) > 0:
return "**ERROR**: " + "".join(error_msg_list), 0
else:
return "".join(result_list[:-1]), result_list[-1]
def _chat(self, history, gen_conf):
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
return super()._chat(history, gen_conf)
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count
return "**ERROR**: " + response.message, tk_count
def _wrap_toolcall_message(self, old_message, message):
if not old_message:
return message
tool_call_id = message["tool_calls"][0].get("id")
if tool_call_id:
old_message.tool_calls[0]["id"] = tool_call_id
function = message.tool_calls[0]["function"]
if function:
if function.get("name"):
old_message.tool_calls[0]["function"]["name"] = function["name"]
if function.get("arguments"):
old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
return old_message
def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
tk_count = 0
try:
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
tool_info = {"content": "", "role": "tool"}
toolcall_message = None
tool_name = ""
tool_arguments = ""
finish_completion = False
reasoning_start = False
while not finish_completion:
for resp in response:
if resp.status_code == HTTPStatus.OK:
assistant_output = resp.output.choices[0].message
ans = resp.output.choices[0].message.content
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
ans = resp.output.choices[0].message.reasoning_content
if not reasoning_start:
reasoning_start = True
ans = "<think>" + ans
else:
ans = ans + "</think>"
if "tool_calls" not in assistant_output:
reasoning_start = False
tk_count += self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
finish_reason = resp.output.choices[0]["finish_reason"]
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
continue
tk_count += self.total_token_count(resp)
toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
if "tool_calls" in assistant_output:
tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
if tool_call_finish_reason == "tool_calls":
try:
tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
except Exception as e:
logging.exception(msg="_chat_streamly_with_tool tool call error")
yield ans + "\n**ERROR**: " + str(e)
finish_completion = True
break
tool_name = toolcall_message.tool_calls[0]["function"]["name"]
history.append(toolcall_message)
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
history.append(tool_info)
tool_info = {"content": "", "role": "tool"}
tool_name = ""
tool_arguments = ""
toolcall_message = None
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
else:
yield (
ans + "\n**ERROR**: " + resp.output.choices[0].message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
logging.exception(msg="_chat_streamly_with_tool")
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
tk_count = 0
try:
response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
for resp in response:
if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]["message"]["content"]
tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
else:
yield (
ans + "\n**ERROR**: " + resp.message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
yield txt
def chat_streamly(self, system, history, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
return super().chat_streamly(system, history, gen_conf)
return self._chat_streamly(system, history, gen_conf)
@staticmethod
def is_reasoning_model(model_name: str) -> bool:
return any(
[
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
]
)
super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
return
class ZhipuChat(Base):
@ -1877,4 +1649,4 @@ class GPUStackChat(Base):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
super().__init__(key, model_name, base_url, **kwargs)