Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this change is not compatible with the older version of the
`Agent` module, which means that agents built with earlier versions will no
longer work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu
Date: 2025-07-30 19:41:09 +08:00
Committed by: GitHub
Parent: 07e37560fc
Commit: d9fe279dde
124 changed files with 7744 additions and 18226 deletions


@@ -18,6 +18,7 @@ import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
@@ -31,25 +32,33 @@ from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
# Error message constants
ERROR_PREFIX = "**ERROR**"
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class LLMErrorCode(StrEnum):
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class ReActMode(StrEnum):
FUNCTION_CALL = "function_call"
REACT = "react"
ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
@@ -73,51 +82,78 @@ class Base(ABC):
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay + random.uniform(60, 150)
return self.base_delay * random.uniform(10, 150)
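The retry delay switches from an additive jitter to one that scales with `base_delay`. A minimal sketch of the difference, assuming `base_delay` is a couple of seconds (its actual default is not shown in this hunk):

```python
import random

base_delay = 2.0  # assumed value; the real default lives elsewhere in the class
old_delay = base_delay + random.uniform(60, 150)   # ~62-152 s, mostly independent of base_delay
new_delay = base_delay * random.uniform(10, 150)   # ~20-300 s, proportional to base_delay
print(f"old: {old_delay:.1f}s  new: {new_delay:.1f}s")
```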
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
return ERROR_RATE_LIMIT
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
return ERROR_AUTHENTICATION
elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str:
return ERROR_INVALID_REQUEST
elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str:
return ERROR_SERVER
elif "timeout" in error_str or "timed out" in error_str:
return ERROR_TIMEOUT
elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str:
return ERROR_CONNECTION
elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or "limit" in error_str and "rate" not in error_str:
return ERROR_QUOTA
elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str or "inappropriate" in error_str:
return ERROR_CONTENT_FILTER
elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str:
return ERROR_MODEL
else:
return ERROR_GENERIC
keywords_mapping = [
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
(["max rounds"], LLMErrorCode.ERROR_MODEL),
]
for words, code in keywords_mapping:
if re.search("({})".format("|".join(words)), error_str):
return code
return LLMErrorCode.ERROR_GENERIC
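The rewritten `_classify_error` replaces the long `if/elif` chain with a keyword table scanned by `re.search`, and quota-related keywords are now checked before rate-limit ones. A standalone sketch of the same pattern, using only a small subset of the keyword table from the patch:

```python
import re
from strenum import StrEnum

class LLMErrorCode(StrEnum):
    ERROR_QUOTA = "QUOTA_EXCEEDED"
    ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
    ERROR_GENERIC = "GENERIC_ERROR"

KEYWORDS_MAPPING = [
    (["quota", "billing", "balance"], LLMErrorCode.ERROR_QUOTA),                  # checked first
    (["rate limit", "429", "too many requests"], LLMErrorCode.ERROR_RATE_LIMIT),
]

def classify_error(error: Exception) -> LLMErrorCode:
    error_str = str(error).lower()
    for words, code in KEYWORDS_MAPPING:
        # build an alternation like "(quota|billing|balance)" and scan the message
        if re.search("({})".format("|".join(words)), error_str):
            return code
    return LLMErrorCode.ERROR_GENERIC

print(classify_error(Exception("HTTP 429: Too Many Requests")).value)  # RATE_LIMIT_EXCEEDED
```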
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
return gen_conf
def _chat(self, history, gen_conf):
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >=0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
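The new `_chat_streamly` helper yields `(delta, token_count)` pairs and wraps any `reasoning_content` in `<think>...</think>` tags. A hypothetical caller-side sketch (`chat_model` is an assumed instance of one of the chat classes below, not something defined in this PR):

```python
history = [{"role": "user", "content": "Summarize the agent redesign in one sentence."}]
answer, total_tokens = "", 0
for delta, tokens in chat_model._chat_streamly(history, {"temperature": 0.7}):
    answer += delta           # deltas may include <think>...</think> reasoning segments
    total_tokens += tokens
print(answer, total_tokens)
```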
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@@ -127,22 +163,24 @@ class Base(ABC):
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
if should_retry:
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
if attempt == self.max_retries:
error_code = ERROR_MAX_RETRIES
should_retry = (error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER)
if not should_retry:
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
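The retry handling is flattened: classify the error, promote the final attempt to `ERROR_MAX_RETRIES`, return immediately for anything that is not a rate-limit or server error, and otherwise back off and retry. A self-contained sketch of that control flow with simplified stand-ins (not the class's actual private helpers):

```python
import logging
import random
import time

RETRYABLE = {"RATE_LIMIT_EXCEEDED", "SERVER_ERROR"}

def classify_error(e: Exception) -> str:
    msg = str(e).lower()
    if "rate limit" in msg or "429" in msg:
        return "RATE_LIMIT_EXCEEDED"
    if "500" in msg or "unavailable" in msg:
        return "SERVER_ERROR"
    return "GENERIC_ERROR"

def call_with_retries(fn, max_retries=3, base_delay=2.0):
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception as e:
            code = classify_error(e)
            if attempt == max_retries:
                code = "MAX_RETRIES_EXCEEDED"       # final attempt: never retry again
            if code not in RETRYABLE:
                return f"**ERROR**: {code} - {e}"   # non-retryable: give up immediately
            delay = base_delay * random.uniform(10, 150)
            logging.warning("Error: %s. Retrying in %.2fs (attempt %d/%d)", code, delay, attempt + 1, max_retries)
            time.sleep(delay)
```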
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
return "<tool_call>" + json.dumps({
"name": name,
"args": args,
"result": res
}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
@@ -172,17 +210,14 @@ class Base(ABC):
if not (toolcall_session and tools):
return
self.is_tools = True
self.toolcall_session = toolcall_session
self.tools = tools
for tool in tools:
self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
self.tools.append(tool)
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
def chat_with_tools(self, system: str, history: list, gen_conf: dict={}):
gen_conf = self._clean_conf(gen_conf)
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
hist = deepcopy(history)
@@ -190,8 +225,9 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds * 2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
for _ in range(self.max_rounds+1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
@@ -207,10 +243,11 @@ class Base(ABC):
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@@ -218,13 +255,20 @@ class Base(ABC):
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning( f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
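`chat_with_tools` now runs at most `max_rounds + 1` tool-calling iterations with `tool_choice="auto"` and routes every tool call through a single `toolcall_session`. A hedged sketch of one such round loop against the OpenAI SDK; the model name, tool schema, and `execute_tool` helper are placeholders, not part of this PR:

```python
import json
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    },
}]

def execute_tool(name: str, args: dict) -> str:
    return f"(pretend result of {name} with {args})"  # stand-in for a real toolcall_session

history = [{"role": "user", "content": "What's the weather in Paris?"}]
for _ in range(3):  # bounded rounds, analogous to max_rounds + 1 in the patch
    resp = client.chat.completions.create(model="gpt-4o-mini", messages=history, tools=tools, tool_choice="auto")
    msg = resp.choices[0].message
    if not msg.tool_calls:
        print(msg.content)
        break
    history.append(msg)  # keep the assistant's tool_calls turn in the history
    for tc in msg.tool_calls:
        result = execute_tool(tc.function.name, json.loads(tc.function.arguments))
        history.append({"role": "tool", "tool_call_id": tc.id, "content": result})
```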
def chat(self, system, history, gen_conf):
def chat(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
@@ -232,7 +276,7 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf)
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
@@ -253,7 +297,7 @@ class Base(ABC):
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict={}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system:
@@ -265,9 +309,10 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds * 2):
for _ in range(self.max_rounds+1):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
@@ -319,7 +364,8 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session[name].tool_call(name, args)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@@ -327,51 +373,45 @@ class Base(ABC):
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning( f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens += tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
yield total_tokens
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf: dict={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens += tol
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -514,7 +554,7 @@ class BaiChuanChat(Base):
"top_p": gen_conf.get("top_p", 0.85),
}
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@@ -529,7 +569,7 @@ class BaiChuanChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@@ -614,7 +654,7 @@ class ZhipuChat(Base):
return super().chat_with_tools(system, history, gen_conf)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@@ -626,6 +666,7 @@ class ZhipuChat(Base):
ans = ""
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
@@ -675,7 +716,7 @@ class OllamaChat(Base):
options[k] = gen_conf[k]
return options
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
# Calculate context size
ctx_size = self._calculate_dynamic_ctx(history)
@@ -685,7 +726,7 @@ class OllamaChat(Base):
token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
return ans, token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@@ -766,7 +807,7 @@ class LocalLLM(Base):
yield answer + "\n**ERROR**: " + str(e)
yield num_tokens_from_string(answer)
def chat(self, system, history, gen_conf):
def chat(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@@ -775,7 +816,7 @@ class LocalLLM(Base):
total_tokens = next(chat_gen)
return ans, total_tokens
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@@ -894,7 +935,7 @@ class MistralChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
@@ -904,7 +945,7 @@ class MistralChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
@@ -913,7 +954,7 @@ class MistralChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
@@ -957,7 +998,7 @@ class BedrockChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = []
for item in history:
@@ -978,7 +1019,7 @@ class BedrockChat(Base):
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
@@ -1036,7 +1077,7 @@ class GeminiChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
from google.generativeai.types import content_types
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
@@ -1059,7 +1100,7 @@ class GeminiChat(Base):
ans = response.text
return ans, response.usage_metadata.total_token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from google.generativeai.types import content_types
gen_conf = self._clean_conf(gen_conf)
@@ -1101,7 +1142,7 @@ class GroqChat(Base):
del gen_conf[k]
return gen_conf
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
@@ -1229,7 +1270,7 @@ class CoHereChat(Base):
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@@ -1348,7 +1389,7 @@ class ReplicateChat(Base):
self.model_name = model_name
self.client = Client(api_token=key)
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
response = self.client.run(
@@ -1358,7 +1399,7 @@ class ReplicateChat(Base):
ans = "".join(response)
return ans, num_tokens_from_string(ans)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
@@ -1402,7 +1443,7 @@ class HunyuanChat(Base):
_gen_conf["TopP"] = gen_conf["top_p"]
return _gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
from tencentcloud.hunyuan.v20230901 import models
hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
@@ -1413,7 +1454,7 @@ class HunyuanChat(Base):
ans = response.Choices[0].Message.Content
return ans, response.Usage.TotalTokens
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
@@ -1504,7 +1545,7 @@ class BaiduYiyanChat(Base):
ans = response["result"]
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@@ -1588,7 +1629,7 @@ class GoogleChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
if "claude" in self.model_name:
response = self.client.messages.create(
@@ -1626,7 +1667,7 @@ class GoogleChat(Base):
ans = response.text
return ans, response.usage_metadata.total_token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "claude" in self.model_name:
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]

File diff suppressed because it is too large.


@@ -89,7 +89,7 @@ class DefaultRerank(Base):
torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")
log_exception(e)
def _process_batch(self, pairs, max_batch_size=None):
"""template method for subclass call"""