Refa: limit embedding concurrency and fix chat_with_tool (#8543)
### What problem does this PR solve?

#8538

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
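The first half of the commit title, limiting embedding concurrency, is not part of the hunks shown below. As a rough idea of what such a limit typically looks like, here is a hedged sketch using an asyncio semaphore; the names `MAX_EMBED_CONCURRENCY`, `embed_one`, and `embed_with_limit` are illustrative placeholders, not ragflow APIs:

```python
import asyncio

MAX_EMBED_CONCURRENCY = 5                      # illustrative cap, not a ragflow setting
_embed_limiter = asyncio.Semaphore(MAX_EMBED_CONCURRENCY)

async def embed_one(texts):
    """Hypothetical embedding call; stands in for whatever client the real code uses."""
    await asyncio.sleep(0.01)
    return [[0.0] * 3 for _ in texts]

async def embed_with_limit(batches):
    async def _one(batch):
        async with _embed_limiter:             # at most MAX_EMBED_CONCURRENCY batches in flight
            return await embed_one(batch)
    return await asyncio.gather(*(_one(b) for b in batches))

# Example: asyncio.run(embed_with_limit([["a"], ["b"], ["c"]]))
```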
@@ -18,11 +18,9 @@ import json
 import logging
 import os
 import random
-import re
 import time
 from abc import ABC
 from copy import deepcopy
-from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
 
@@ -61,9 +59,6 @@ class ToolCallSession(Protocol):
 
 
 class Base(ABC):
-    tools: list[Any]
-    toolcall_sessions: dict[str, ToolCallSession]
-
     def __init__(self, key, model_name, base_url, **kwargs):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@@ -146,6 +141,37 @@ class Base(ABC):
                 error_code = ERROR_MAX_RETRIES
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
+    def _verbose_tool_use(self, name, args, res):
+        return "<tool_call>" + json.dumps({
+            "name": name,
+            "args": args,
+            "result": res
+        }, ensure_ascii=False, indent=2) + "</tool_call>"
+
+    def _append_history(self, hist, tool_call, tool_res):
+        hist.append(
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "index": tool_call.index,
+                        "id": tool_call.id,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments,
+                        },
+                        "type": "function",
+                    },
+                ],
+            }
+        )
+        try:
+            if isinstance(tool_res, dict):
+                tool_res = json.dumps(tool_res, ensure_ascii=False)
+        finally:
+            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
+        return hist
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
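The two helpers added above centralize how a tool invocation is recorded: `_verbose_tool_use` renders a `<tool_call>` block that is appended to the visible answer, and `_append_history` writes the assistant message declaring the call followed by the matching `role: "tool"` result message. A minimal standalone sketch of the messages they produce; the `SimpleNamespace` stand-in for the SDK's tool-call object is hypothetical, for illustration only:

```python
import json
from types import SimpleNamespace

# Hypothetical stand-in for an OpenAI SDK tool-call object (illustration only).
tool_call = SimpleNamespace(
    index=0,
    id="call_0",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "Paris"}'),
)

hist = [{"role": "user", "content": "Weather in Paris?"}]
tool_res = {"temp_c": 21}

# What _append_history produces: the assistant message that declares the tool call,
# immediately followed by the tool message carrying its result.
hist.append({
    "role": "assistant",
    "tool_calls": [{
        "index": tool_call.index,
        "id": tool_call.id,
        "function": {"name": tool_call.function.name, "arguments": tool_call.function.arguments},
        "type": "function",
    }],
})
hist.append({"role": "tool", "tool_call_id": tool_call.id,
             "content": json.dumps(tool_res, ensure_ascii=False)})

# What _verbose_tool_use appends to the user-visible answer.
verbose = "<tool_call>" + json.dumps(
    {"name": "get_weather", "args": {"city": "Paris"}, "result": tool_res},
    ensure_ascii=False, indent=2) + "</tool_call>"
print(verbose)
```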
@@ -160,18 +186,19 @@ class Base(ABC):
         if system:
             history.insert(0, {"role": "system", "content": system})
 
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries+1):
             history = hist
-            for _ in range(self.max_rounds * 2):
             try:
+                for _ in range(self.max_rounds*2):
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
-                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
-                        raise Exception("500 response structure error.")
+                    if any([not response.choices, not response.choices[0].message]):
+                        raise Exception(f"500 response structure error. Response: {response}")
 
                     if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                         if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@@ -188,9 +215,12 @@ class Base(ABC):
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
                             tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            history = self._append_history(history, tool_call, tool_response)
+                            ans += self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            ans += self._verbose_tool_use(name, {}, str(e))
 
             except Exception as e:
                 e = self._exceptions(e, attempt)
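Both tool-calling paths keep the `# Implement exponential backoff retry strategy` loop: the whole rounds loop now sits inside a single `try`, and `self._exceptions(e, attempt)` decides whether to retry or surface the error. The helper's body is not part of this diff, so the sketch below is only a plausible shape for such a backoff policy; the delay constants, `retry_with_backoff` name, and the `ERROR_PREFIX` handling are assumptions:

```python
import random
import time

ERROR_PREFIX = "**ERROR**"  # assumed sentinel, mirroring the f"{ERROR_PREFIX}: ..." strings above

def retry_with_backoff(call, max_retries=3, base_delay=1.0):
    """Hypothetical illustration of the retry loop wrapping `call`."""
    for attempt in range(max_retries + 1):
        try:
            return call()
        except Exception as e:
            if attempt == max_retries:
                # give up and surface a tagged error string, as the _exceptions path does
                return f"{ERROR_PREFIX}: {e}"
            # exponential backoff with jitter before the next attempt
            time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.5))
```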
@@ -228,9 +258,7 @@ class Base(ABC):
         return final_tool_calls
 
     def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
+        gen_conf = self._clean_conf(gen_conf)
         tools = self.tools
         if system:
             history.insert(0, {"role": "system", "content": system})
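Several entry points now call `self._clean_conf(gen_conf)` instead of deleting `max_tokens` inline. The helper's implementation is outside this diff, so this is only a sketch of what such a sanitizer might do; the function name and the exact keys it strips are assumptions:

```python
def clean_conf(gen_conf: dict) -> dict:
    """Hypothetical sketch of a _clean_conf-style helper: drop options the backend should not receive."""
    cleaned = dict(gen_conf)         # avoid mutating the caller's dict
    cleaned.pop("max_tokens", None)  # the key the old inline code deleted
    return cleaned

print(clean_conf({"temperature": 0.2, "max_tokens": 512}))  # {'temperature': 0.2}
```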
@@ -240,9 +268,9 @@ class Base(ABC):
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries + 1):
             history = hist
+            try:
                 for _ in range(self.max_rounds*2):
                     reasoning_start = False
-                try:
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
                     answer = ""
@@ -252,9 +280,11 @@ class Base(ABC):
                         index = tool_call.index
 
                         if index not in final_tool_calls:
+                            if not tool_call.function.arguments:
+                                tool_call.function.arguments = ""
                             final_tool_calls[index] = tool_call
                         else:
-                            final_tool_calls[index].function.arguments += tool_call.function.arguments
+                            final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
                         continue
 
                     if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
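In streaming mode a tool call arrives as a sequence of deltas that share an `index`: the first chunk carries the id and function name, later chunks carry argument fragments that may also be `None`. The change above normalizes missing fragments to `""` before concatenating. A standalone sketch of that accumulation; the `SimpleNamespace` chunks are hypothetical stand-ins for SDK delta objects:

```python
from types import SimpleNamespace

# Hypothetical streamed deltas for a single tool call (index 0).
deltas = [
    SimpleNamespace(index=0, id="call_0",
                    function=SimpleNamespace(name="get_weather", arguments=None)),
    SimpleNamespace(index=0, id=None,
                    function=SimpleNamespace(name=None, arguments='{"city": ')),
    SimpleNamespace(index=0, id=None,
                    function=SimpleNamespace(name=None, arguments='"Paris"}')),
]

final_tool_calls = {}
for tool_call in deltas:
    index = tool_call.index
    if index not in final_tool_calls:
        if not tool_call.function.arguments:
            tool_call.function.arguments = ""   # guard against None on the first chunk
        final_tool_calls[index] = tool_call
    else:
        final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""

print(final_tool_calls[0].function.arguments)   # {"city": "Paris"}
```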
@@ -293,40 +323,26 @@ class Base(ABC):
                         name = tool_call.function.name
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
-                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append(
-                                {
-                                    "role": "assistant",
-                                    "tool_calls": [
-                                        {
-                                            "index": tool_call.index,
-                                            "id": tool_call.id,
-                                            "function": {
-                                                "name": tool_call.function.name,
-                                                "arguments": tool_call.function.arguments,
-                                            },
-                                            "type": "function",
-                                        },
-                                    ],
-                                }
-                            )
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            tool_response = self.toolcall_session[name].tool_call(name, args)
+                            history = self._append_history(history, tool_call, tool_response)
+                            yield self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
                             logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            yield self._verbose_tool_use(name, {}, str(e))
 
             except Exception as e:
                 e = self._exceptions(e, attempt)
                 if e:
                     yield total_tokens
                     return
 
-        assert False, "Shouldn't be here."
+        yield total_tokens
 
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
         reasoning_start = False
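After this change `chat_streamly_with_tools` always terminates by yielding the accumulated token count instead of asserting, so a consumer can treat the final item of the generator as the usage figure and everything before it as streamed text (the same `result_list[:-1]` / `result_list[-1]` convention the removed QWenChat code below relied on). A self-contained sketch of that convention; `fake_stream` is a hypothetical stand-in for the real generator:

```python
def fake_stream():
    # Hypothetical stand-in for chat_streamly_with_tools: text chunks, then the token count.
    yield "Hello, "
    yield "world."
    yield 42  # final yield is the accumulated token count

chunks = list(fake_stream())
answer_text = "".join(str(c) for c in chunks[:-1])  # "Hello, world."
token_usage = chunks[-1]                            # 42
```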
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):
 
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
+        super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+        return
-        import dashscope
-
-        dashscope.api_key = key
-        self.model_name = model_name
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
-
-    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        # if self.is_reasoning_model(self.model_name):
-        #     return super().chat(system, history, gen_conf)
-
-        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
-        if not stream_flag:
-            from http import HTTPStatus
-
-            tools = self.tools
-
-            if system:
-                history.insert(0, {"role": "system", "content": system})
-
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
-            ans = ""
-            tk_count = 0
-            if response.status_code == HTTPStatus.OK:
-                assistant_output = response.output.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.output.choices[0].message.content
-
-                if "tool_calls" not in assistant_output:
-                    tk_count += self.total_token_count(response)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
-
-                while "tool_calls" in assistant_output:
-                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
-                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
-                    if tool_name:
-                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
-                        tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
-                        history.append(tool_info)
-
-                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        tk_count += self.total_token_count(response)
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, tk_count
-
-                    tk_count += self.total_token_count(response)
-                    assistant_output = response.output.choices[0].message
-                    if assistant_output.content is None:
-                        assistant_output.content = ""
-                    history.append(response)
-                    ans += assistant_output["content"]
-                return ans, tk_count
-            else:
-                return "**ERROR**: " + response.message, tk_count
-        else:
-            result_list = []
-            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
-                result_list.append(result)
-            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
-            if len(error_msg_list) > 0:
-                return "**ERROR**: " + "".join(error_msg_list), 0
-            else:
-                return "".join(result_list[:-1]), result_list[-1]
-
-    def _chat(self, history, gen_conf):
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super()._chat(history, gen_conf)
-        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]["message"]["content"]
-            tk_count += self.total_token_count(response)
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                if is_chinese([ans]):
-                    ans += LENGTH_NOTIFICATION_CN
-                else:
-                    ans += LENGTH_NOTIFICATION_EN
-            return ans, tk_count
-        return "**ERROR**: " + response.message, tk_count
-
-    def _wrap_toolcall_message(self, old_message, message):
-        if not old_message:
-            return message
-        tool_call_id = message["tool_calls"][0].get("id")
-        if tool_call_id:
-            old_message.tool_calls[0]["id"] = tool_call_id
-        function = message.tool_calls[0]["function"]
-        if function:
-            if function.get("name"):
-                old_message.tool_calls[0]["function"]["name"] = function["name"]
-            if function.get("arguments"):
-                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
-        return old_message
-
-    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-            tool_info = {"content": "", "role": "tool"}
-            toolcall_message = None
-            tool_name = ""
-            tool_arguments = ""
-            finish_completion = False
-            reasoning_start = False
-            while not finish_completion:
-                for resp in response:
-                    if resp.status_code == HTTPStatus.OK:
-                        assistant_output = resp.output.choices[0].message
-                        ans = resp.output.choices[0].message.content
-                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                            ans = resp.output.choices[0].message.reasoning_content
-                            if not reasoning_start:
-                                reasoning_start = True
-                                ans = "<think>" + ans
-                            else:
-                                ans = ans + "</think>"
-
-                        if "tool_calls" not in assistant_output:
-                            reasoning_start = False
-                            tk_count += self.total_token_count(resp)
-                            if resp.output.choices[0].get("finish_reason", "") == "length":
-                                if is_chinese([ans]):
-                                    ans += LENGTH_NOTIFICATION_CN
-                                else:
-                                    ans += LENGTH_NOTIFICATION_EN
-                            finish_reason = resp.output.choices[0]["finish_reason"]
-                            if finish_reason == "stop":
-                                finish_completion = True
-                                yield ans
-                                break
-                            yield ans
-                            continue
-
-                        tk_count += self.total_token_count(resp)
-                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
-                        if "tool_calls" in assistant_output:
-                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
-                            if tool_call_finish_reason == "tool_calls":
-                                try:
-                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
-                                except Exception as e:
-                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
-
-                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
-                                history.append(toolcall_message)
-                                tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
-                                history.append(tool_info)
-                                tool_info = {"content": "", "role": "tool"}
-                                tool_name = ""
-                                tool_arguments = ""
-                                toolcall_message = None
-                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-                    else:
-                        yield (
-                            ans + "\n**ERROR**: " + resp.output.choices[0].message
-                            if not re.search(r" (key|quota)", str(resp.message).lower())
-                            else "Out of credit. Please set the API key in **settings > Model providers.**"
-                        )
-        except Exception as e:
-            logging.exception(msg="_chat_streamly_with_tool")
-            yield ans + "\n**ERROR**: " + str(e)
-        yield tk_count
-
-    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
-            for resp in response:
-                if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]["message"]["content"]
-                    tk_count = self.total_token_count(resp)
-                    if resp.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    yield ans
-                else:
-                    yield (
-                        ans + "\n**ERROR**: " + resp.message
-                        if not re.search(r" (key|quota)", str(resp.message).lower())
-                        else "Out of credit. Please set the API key in **settings > Model providers.**"
-                    )
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
-            yield txt
-
-    def chat_streamly(self, system, history, gen_conf):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super().chat_streamly(system, history, gen_conf)
-
-        return self._chat_streamly(system, history, gen_conf)
-
-    @staticmethod
-    def is_reasoning_model(model_name: str) -> bool:
-        return any(
-            [
-                model_name.lower().find("deepseek") >= 0,
-                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
-            ]
-        )
 
 
 class ZhipuChat(Base):
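With the DashScope-specific implementation removed, QWenChat behaves like any other OpenAI-compatible provider: its constructor points the inherited OpenAI client at DashScope's compatible-mode endpoint and returns early, so chat, streaming, and tool calling all go through the Base methods patched above. Roughly equivalent to the following; the API key and model name are placeholders:

```python
from openai import OpenAI

# Placeholder credentials and model; illustrative only.
client = OpenAI(api_key="sk-...", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", timeout=600)
resp = client.chat.completions.create(
    model="qwen-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)
```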