Refa: limit embedding concurrency and fix chat_with_tool (#8543)

### What problem does this PR solve?

#8538

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
Authored by Kevin Hu on 2025-06-27 19:28:41 +08:00, committed by GitHub.
Parent 8e1f8a0c48, commit e441c17c2c.
2 changed files with 75 additions and 303 deletions.


```diff
@@ -18,11 +18,9 @@ import json
 import logging
 import os
 import random
-import re
 import time
 from abc import ABC
 from copy import deepcopy
-from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
```
@ -61,9 +59,6 @@ class ToolCallSession(Protocol):
class Base(ABC): class Base(ABC):
tools: list[Any]
toolcall_sessions: dict[str, ToolCallSession]
def __init__(self, key, model_name, base_url, **kwargs): def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
```diff
@@ -146,6 +141,37 @@
             error_code = ERROR_MAX_RETRIES
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
+    def _verbose_tool_use(self, name, args, res):
+        return "<tool_call>" + json.dumps({
+            "name": name,
+            "args": args,
+            "result": res
+        }, ensure_ascii=False, indent=2) + "</tool_call>"
+
+    def _append_history(self, hist, tool_call, tool_res):
+        hist.append(
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "index": tool_call.index,
+                        "id": tool_call.id,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments,
+                        },
+                        "type": "function",
+                    },
+                ],
+            }
+        )
+        try:
+            if isinstance(tool_res, dict):
+                tool_res = json.dumps(tool_res, ensure_ascii=False)
+        finally:
+            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
+        return hist
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
```
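The new `_append_history` helper appears to be the substance of the `chat_with_tool` fix: OpenAI-compatible endpoints require every `role: "tool"` message to be preceded by an assistant message whose `tool_calls` entry carries the matching `tool_call_id`, and the non-streaming path previously appended only the tool message (see the hunk after next). The `try/finally` also guarantees the tool message lands in the history even if serializing a dict result fails. A minimal sketch of the message sequence the helper builds; the ids, names, and values here are invented placeholders, not taken from the PR:

```python
# Sketch of the history shape _append_history produces (placeholder values).
history = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {   # assistant turn that requested the tool
        "role": "assistant",
        "tool_calls": [{
            "index": 0,
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }],
    },
    {   # tool result; tool_call_id must match the id above, otherwise the
        # OpenAI-compatible endpoint rejects the next completion request
        "role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_c": 21}'},
]
```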
```diff
@@ -160,18 +186,19 @@
         if system:
             history.insert(0, {"role": "system", "content": system})
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries+1):
             history = hist
-            try:
-                for _ in range(self.max_rounds*2):
+            for _ in range(self.max_rounds * 2):
+                try:
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
-                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
-                        raise Exception("500 response structure error.")
+                    if any([not response.choices, not response.choices[0].message]):
+                        raise Exception(f"500 response structure error. Response: {response}")
                     if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                         if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
```
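Two smaller fixes sit in this hunk: the structure check no longer rejects responses whose `message.content` is empty (legitimate when the model returns only `tool_calls`), and the error now includes the raw response for debugging. The loop nesting also swaps: the `try` moves inside each round, so one failed request is classified by `self._exceptions()` (visible at the tail of the next hunk) without abandoning the remaining rounds and retry attempts. A toy, runnable model of that nesting; names mirror the diff, the "API call" is a stub that fails once:

```python
# Toy model of the retry/round nesting after this change (stubbed API call).
import itertools

max_retries, max_rounds = 2, 3
calls = itertools.count()

def fake_completion():
    if next(calls) == 0:                  # first request fails ...
        raise TimeoutError("transient error")
    return "final answer"                 # ... later ones succeed

result = None
for attempt in range(max_retries + 1):
    for _ in range(max_rounds * 2):       # bound tool round-trips
        try:
            result = fake_completion()    # stands in for chat.completions.create
            break                         # model produced a final answer
        except Exception:
            continue                      # real code routes this through self._exceptions()
    if result:
        break
print(result)  # "final answer": one transient failure did not abort the call
```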
```diff
@@ -188,9 +215,12 @@
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
                             tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            history = self._append_history(history, tool_call, tool_response)
+                            ans += self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            ans += self._verbose_tool_use(name, {}, str(e))
 
                 except Exception as e:
                     e = self._exceptions(e, attempt)
```
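Beyond the history fix, both branches now surface tool activity to the caller by appending a `<tool_call>…</tool_call>` trace to `ans`, and failures are logged via `logging.exception` instead of disappearing into the tool message. Note the arguments are parsed with `json_repair.loads` (the same call already used in the streaming path), which tolerates the almost-JSON that models tend to emit. Illustrative values only:

```python
import json_repair

# LLMs often emit slightly broken JSON: single quotes, trailing commas, etc.
# json_repair.loads() repairs and parses in one step, like json.loads otherwise.
args = json_repair.loads("{'city': 'Paris', 'days': 3,}")
print(args)  # {'city': 'Paris', 'days': 3}
```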
```diff
@@ -228,9 +258,7 @@
         return final_tool_calls
 
     def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
+        gen_conf = self._clean_conf(gen_conf)
         tools = self.tools
         if system:
             history.insert(0, {"role": "system", "content": system})
```
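`_clean_conf` replaces the inline `max_tokens` deletion here and in the other chat methods. Its body is not part of this diff, so the following is a hypothetical sketch for orientation only, assuming it centralizes the same parameter stripping the deleted lines performed:

```python
def _clean_conf(self, gen_conf: dict) -> dict:
    # Hypothetical sketch; the real implementation lives outside this diff.
    # The visible behavior it replaces is dropping "max_tokens", which some
    # OpenAI-compatible backends reject.
    gen_conf = dict(gen_conf)        # avoid mutating the caller's dict
    gen_conf.pop("max_tokens", None)
    return gen_conf
```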
```diff
@@ -240,9 +268,9 @@
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries + 1):
             history = hist
-            for _ in range(self.max_rounds*2):
-                reasoning_start = False
-                try:
+            try:
+                for _ in range(self.max_rounds*2):
+                    reasoning_start = False
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
                     answer = ""
```
```diff
@@ -252,9 +280,11 @@
                             index = tool_call.index
                             if index not in final_tool_calls:
+                                if not tool_call.function.arguments:
+                                    tool_call.function.arguments = ""
                                 final_tool_calls[index] = tool_call
                             else:
-                                final_tool_calls[index].function.arguments += tool_call.function.arguments
+                                final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
                         continue
                     if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
```
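With streaming enabled, one logical tool call arrives as many deltas: the first typically carries the id and function name, later ones carry argument fragments, and any delta may have `arguments` set to `None`. The two added guards normalize `None` to `""` so the `+=` concatenation never raises `TypeError`. Schematically, with delta objects simplified to dicts and invented values:

```python
# Invented example deltas for one streamed tool call (index 0).
deltas = [
    {"index": 0, "id": "call_1", "name": "search", "arguments": None},
    {"index": 0, "id": None, "name": None, "arguments": '{"query": "ragf'},
    {"index": 0, "id": None, "name": None, "arguments": 'low"}'},
]
final = {}
for d in deltas:
    if d["index"] not in final:
        d["arguments"] = d["arguments"] or ""            # the added guard
        final[d["index"]] = d
    else:
        final[d["index"]]["arguments"] += d["arguments"] or ""
print(final[0]["arguments"])  # {"query": "ragflow"}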
```diff
@@ -293,40 +323,26 @@
                         name = tool_call.function.name
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
-                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append(
-                                {
-                                    "role": "assistant",
-                                    "tool_calls": [
-                                        {
-                                            "index": tool_call.index,
-                                            "id": tool_call.id,
-                                            "function": {
-                                                "name": tool_call.function.name,
-                                                "arguments": tool_call.function.arguments,
-                                            },
-                                            "type": "function",
-                                        },
-                                    ],
-                                }
-                            )
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            tool_response = self.toolcall_session[name].tool_call(name, args)
+                            history = self._append_history(history, tool_call, tool_response)
+                            yield self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
                             logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            yield self._verbose_tool_use(name, {}, str(e))
 
             except Exception as e:
                 e = self._exceptions(e, attempt)
                 if e:
                     yield total_tokens
                     return
-        assert False, "Shouldn't be here."
+        yield total_tokens
 
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
         reasoning_start = False
```
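Replacing `assert False, "Shouldn't be here."` with `yield total_tokens` matters for callers: the generator's contract, also visible in the error path above, is to stream text chunks and finish by yielding the accumulated token count. Previously, exhausting all rounds without returning tripped the assert; now the count is always the final item. A consumer sketch under that contract, with placeholder model, history, and config values:

```python
# Sketch of consuming chat_streamly_with_tools: every yield except the
# last is a str chunk; the final yield is an int token count.
chunks = []
token_count = 0
for item in model.chat_streamly_with_tools("You are helpful.", history, {"temperature": 0.7}):
    if isinstance(item, int):
        token_count = item       # final yield: accumulated token usage
    else:
        chunks.append(item)      # streamed text and <tool_call> traces
answer = "".join(chunks)
```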
```diff
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-        return
-        import dashscope
-        dashscope.api_key = key
-        self.model_name = model_name
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+        super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
-
-    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        # if self.is_reasoning_model(self.model_name):
-        #     return super().chat(system, history, gen_conf)
-        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
-        if not stream_flag:
-            from http import HTTPStatus
-            tools = self.tools
-            if system:
-                history.insert(0, {"role": "system", "content": system})
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
-            ans = ""
-            tk_count = 0
-            if response.status_code == HTTPStatus.OK:
-                assistant_output = response.output.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.output.choices[0].message.content
-                if "tool_calls" not in assistant_output:
-                    tk_count += self.total_token_count(response)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
-                while "tool_calls" in assistant_output:
-                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
-                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
-                    if tool_name:
-                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
-                        tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
-                    history.append(tool_info)
-                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        tk_count += self.total_token_count(response)
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, tk_count
-                    tk_count += self.total_token_count(response)
-                    assistant_output = response.output.choices[0].message
-                    if assistant_output.content is None:
-                        assistant_output.content = ""
-                    history.append(response)
-                    ans += assistant_output["content"]
-                return ans, tk_count
-            else:
-                return "**ERROR**: " + response.message, tk_count
-        else:
-            result_list = []
-            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
-                result_list.append(result)
-            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
-            if len(error_msg_list) > 0:
-                return "**ERROR**: " + "".join(error_msg_list), 0
-            else:
-                return "".join(result_list[:-1]), result_list[-1]
-
-    def _chat(self, history, gen_conf):
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super()._chat(history, gen_conf)
-        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]["message"]["content"]
-            tk_count += self.total_token_count(response)
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                if is_chinese([ans]):
-                    ans += LENGTH_NOTIFICATION_CN
-                else:
-                    ans += LENGTH_NOTIFICATION_EN
-            return ans, tk_count
-        return "**ERROR**: " + response.message, tk_count
-
-    def _wrap_toolcall_message(self, old_message, message):
-        if not old_message:
-            return message
-        tool_call_id = message["tool_calls"][0].get("id")
-        if tool_call_id:
-            old_message.tool_calls[0]["id"] = tool_call_id
-        function = message.tool_calls[0]["function"]
-        if function:
-            if function.get("name"):
-                old_message.tool_calls[0]["function"]["name"] = function["name"]
-            if function.get("arguments"):
-                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
-        return old_message
-
-    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-            tool_info = {"content": "", "role": "tool"}
-            toolcall_message = None
-            tool_name = ""
-            tool_arguments = ""
-            finish_completion = False
-            reasoning_start = False
-            while not finish_completion:
-                for resp in response:
-                    if resp.status_code == HTTPStatus.OK:
-                        assistant_output = resp.output.choices[0].message
-                        ans = resp.output.choices[0].message.content
-                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                            ans = resp.output.choices[0].message.reasoning_content
-                            if not reasoning_start:
-                                reasoning_start = True
-                                ans = "<think>" + ans
-                            else:
-                                ans = ans + "</think>"
-                        if "tool_calls" not in assistant_output:
-                            reasoning_start = False
-                            tk_count += self.total_token_count(resp)
-                            if resp.output.choices[0].get("finish_reason", "") == "length":
-                                if is_chinese([ans]):
-                                    ans += LENGTH_NOTIFICATION_CN
-                                else:
-                                    ans += LENGTH_NOTIFICATION_EN
-                            finish_reason = resp.output.choices[0]["finish_reason"]
-                            if finish_reason == "stop":
-                                finish_completion = True
-                                yield ans
-                                break
-                            yield ans
-                            continue
-                        tk_count += self.total_token_count(resp)
-                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
-                        if "tool_calls" in assistant_output:
-                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
-                            if tool_call_finish_reason == "tool_calls":
-                                try:
-                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
-                                except Exception as e:
-                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
-                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
-                                history.append(toolcall_message)
-                                tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
-                                history.append(tool_info)
-                                tool_info = {"content": "", "role": "tool"}
-                                tool_name = ""
-                                tool_arguments = ""
-                                toolcall_message = None
-                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-                    else:
-                        yield (
-                            ans + "\n**ERROR**: " + resp.output.choices[0].message
-                            if not re.search(r" (key|quota)", str(resp.message).lower())
-                            else "Out of credit. Please set the API key in **settings > Model providers.**"
-                        )
-        except Exception as e:
-            logging.exception(msg="_chat_streamly_with_tool")
-            yield ans + "\n**ERROR**: " + str(e)
-        yield tk_count
-
-    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
-            for resp in response:
-                if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]["message"]["content"]
-                    tk_count = self.total_token_count(resp)
-                    if resp.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    yield ans
-                else:
-                    yield (
-                        ans + "\n**ERROR**: " + resp.message
-                        if not re.search(r" (key|quota)", str(resp.message).lower())
-                        else "Out of credit. Please set the API key in **settings > Model providers.**"
-                    )
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-        yield tk_count
-
-    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
-            yield txt
-
-    def chat_streamly(self, system, history, gen_conf):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super().chat_streamly(system, history, gen_conf)
-        return self._chat_streamly(system, history, gen_conf)
-
-    @staticmethod
-    def is_reasoning_model(model_name: str) -> bool:
-        return any(
-            [
-                model_name.lower().find("deepseek") >= 0,
-                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
-            ]
-        )
 
 class ZhipuChat(Base):
```
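The net effect of this last hunk: `QWenChat` drops its entire DashScope-SDK code path (the non-streaming and streaming variants, the `_wrap_toolcall_message` delta rewrapping, and the reasoning-model special cases) and always talks to DashScope's OpenAI-compatible endpoint through the inherited `Base` implementation, which now carries the tool-calling fixes above. Roughly, every Qwen call reduces to the standard client flow; illustrative values only:

```python
from openai import OpenAI

# Illustrative only: this is what Base.__init__ sets up once QWenChat
# pins base_url to DashScope's OpenAI-compatible endpoint.
client = OpenAI(
    api_key="sk-...",  # DashScope API key (placeholder)
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
resp = client.chat.completions.create(
    model="qwen-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)
```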