Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this is not compatible with the older version of the `Agent` module, which means that an `Agent` from an older version no longer works.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
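For orientation before the diff, here is a minimal sketch of how a caller might drive the redesigned tool-calling interface. The `SearchToolSession` class, the `WEB_SEARCH_TOOL` schema, and the `ask_with_tools` helper are hypothetical; only `bind_tools` and `chat_with_tools` come from the code changed in this commit, and `chat_mdl` stands in for an instance of one of the chat model classes below.

# Hypothetical tool session: any object exposing tool_call(name, args) can be bound.
class SearchToolSession:
    def tool_call(self, name, args):
        return f"stub result for {name}({args})"

WEB_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the web for a query.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}

def ask_with_tools(chat_mdl):
    # chat_mdl is assumed to be an instance of one of the chat model classes
    # defined in this file (a subclass of Base); construction is provider-specific.
    chat_mdl.bind_tools(SearchToolSession(), [WEB_SEARCH_TOOL])
    answer, tokens = chat_mdl.chat_with_tools(
        "You are a helpful assistant.",
        [{"role": "user", "content": "What's new in RAGFlow?"}],
        {"temperature": 0.2},
    )
    return answer, tokens
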
@@ -18,6 +18,7 @@ import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
@@ -31,25 +32,33 @@ from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI

from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string

# Error message constants
ERROR_PREFIX = "**ERROR**"
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class LLMErrorCode(StrEnum):
    ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
    ERROR_AUTHENTICATION = "AUTH_ERROR"
    ERROR_INVALID_REQUEST = "INVALID_REQUEST"
    ERROR_SERVER = "SERVER_ERROR"
    ERROR_TIMEOUT = "TIMEOUT"
    ERROR_CONNECTION = "CONNECTION_ERROR"
    ERROR_MODEL = "MODEL_ERROR"
    ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
    ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
    ERROR_QUOTA = "QUOTA_EXCEEDED"
    ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
    ERROR_GENERIC = "GENERIC_ERROR"


class ReActMode(StrEnum):
    FUNCTION_CALL = "function_call"
    REACT = "react"

ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."

@@ -73,51 +82,78 @@ class Base(ABC):

    def _get_delay(self):
        """Calculate retry delay time"""
        return self.base_delay + random.uniform(60, 150)
        return self.base_delay * random.uniform(10, 150)

    def _classify_error(self, error):
        """Classify error based on error message content"""
        error_str = str(error).lower()

        if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
            return ERROR_RATE_LIMIT
        elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
            return ERROR_AUTHENTICATION
        elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str:
            return ERROR_INVALID_REQUEST
        elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str:
            return ERROR_SERVER
        elif "timeout" in error_str or "timed out" in error_str:
            return ERROR_TIMEOUT
        elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str:
            return ERROR_CONNECTION
        elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or "limit" in error_str and "rate" not in error_str:
            return ERROR_QUOTA
        elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str or "inappropriate" in error_str:
            return ERROR_CONTENT_FILTER
        elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str:
            return ERROR_MODEL
        else:
            return ERROR_GENERIC
        keywords_mapping = [
            (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
            (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
            (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
            (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
            (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
            (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
            (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
            (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
            (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
            (["max rounds"], LLMErrorCode.ERROR_MODEL),
        ]
        for words, code in keywords_mapping:
            if re.search("({})".format("|".join(words)), error_str):
                return code

        return LLMErrorCode.ERROR_GENERIC

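The new `_classify_error` replaces the long `if/elif` chain with a keyword-to-code table scanned by a regular expression. A self-contained sketch of the same technique, with illustrative names and a trimmed keyword table rather than the commit's full mapping:

import re
from enum import Enum

class ErrorCode(str, Enum):
    RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
    QUOTA = "QUOTA_EXCEEDED"
    GENERIC = "GENERIC_ERROR"

# Order matters: the first keyword group that matches wins.
KEYWORDS = [
    (["quota", "billing", "balance"], ErrorCode.QUOTA),
    (["rate limit", "429", "too many requests"], ErrorCode.RATE_LIMIT),
]

def classify(error: Exception) -> ErrorCode:
    text = str(error).lower()
    for words, code in KEYWORDS:
        # Build an alternation like "(quota|billing|balance)" and search the message.
        if re.search("({})".format("|".join(map(re.escape, words))), text):
            return code
    return ErrorCode.GENERIC

print(classify(RuntimeError("HTTP 429: too many requests")))  # ErrorCode.RATE_LIMIT
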

    def _clean_conf(self, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf

    def _chat(self, history, gen_conf):
        response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
    def _chat(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
        if self.model_name.lower().find("qwen3") >=0:
            kwargs["extra_body"] = {"enable_thinking": False}
        response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)

        if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
            return "", 0
        ans = response.choices[0].message.content.strip()
        if response.choices[0].finish_reason == "length":
            if is_chinese(ans):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
            ans = self._length_stop(ans)
        return ans, self.total_token_count(response)

    def _chat_streamly(self, history, gen_conf, **kwargs):
        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
        reasoning_start = False
        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
        for resp in response:
            if not resp.choices:
                continue
            if not resp.choices[0].delta.content:
                resp.choices[0].delta.content = ""
            if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                ans = ""
                if not reasoning_start:
                    reasoning_start = True
                    ans = "<think>"
                ans += resp.choices[0].delta.reasoning_content + "</think>"
            else:
                reasoning_start = False
                ans = resp.choices[0].delta.content

            tol = self.total_token_count(resp)
            if not tol:
                tol = num_tokens_from_string(resp.choices[0].delta.content)

            if resp.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            yield ans, tol

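The new `_chat_streamly` wraps any `reasoning_content` deltas in `<think>...</think>` markers before yielding them alongside a token count. A standalone sketch of that wrapping logic over a fake delta stream (the `Delta` dataclass and the sample data are invented for illustration):

from dataclasses import dataclass

@dataclass
class Delta:
    content: str = ""
    reasoning_content: str = ""

def stream_with_think(deltas):
    # Emit reasoning wrapped in <think> tags, plain content as-is.
    reasoning_start = False
    for d in deltas:
        if d.reasoning_content:
            chunk = "" if reasoning_start else "<think>"
            reasoning_start = True
            chunk += d.reasoning_content + "</think>"
        else:
            reasoning_start = False
            chunk = d.content
        yield chunk

fake = [Delta(reasoning_content="Let me check."), Delta(content="The answer is 42.")]
print("".join(stream_with_think(fake)))
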
    def _length_stop(self, ans):
        if is_chinese([ans]):
            return ans + LENGTH_NOTIFICATION_CN
@@ -127,22 +163,24 @@ class Base(ABC):
        logging.exception("OpenAI chat_with_tools")
        # Classify the error
        error_code = self._classify_error(e)
        if attempt == self.max_retries:
            error_code = LLMErrorCode.ERROR_MAX_RETRIES

        # Check if it's a rate limit error or server error and not the last attempt
        should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries

        if should_retry:
            delay = self._get_delay()
            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
            time.sleep(delay)
        else:
            # For non-rate limit errors or the last attempt, return an error message
            if attempt == self.max_retries:
                error_code = ERROR_MAX_RETRIES
        should_retry = (error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER)
        if not should_retry:
            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

        delay = self._get_delay()
        logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
        time.sleep(delay)

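The retry policy above sleeps and retries only for rate-limit and server errors; everything else fails fast with an `**ERROR**`-prefixed message. A hedged, self-contained sketch of that pattern (the `call_with_retry` helper, the `call_llm` callable, and the delay constants are illustrative, not the commit's code):

import logging
import random
import time

MAX_RETRIES = 3
BASE_DELAY = 2  # seconds; illustrative value

def get_delay():
    # Jittered delay so concurrent workers do not retry in lockstep.
    return BASE_DELAY * random.uniform(10, 150)

def call_with_retry(call_llm, classify):
    for attempt in range(MAX_RETRIES + 1):
        try:
            return call_llm()
        except Exception as e:
            code = classify(e)
            retryable = code in ("RATE_LIMIT_EXCEEDED", "SERVER_ERROR") and attempt < MAX_RETRIES
            if not retryable:
                return f"**ERROR**: {code} - {e}"
            delay = get_delay()
            logging.warning("Error: %s. Retrying in %.2f seconds (attempt %d/%d)", code, delay, attempt + 1, MAX_RETRIES)
            time.sleep(delay)
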
    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
        return "<tool_call>" + json.dumps({
            "name": name,
            "args": args,
            "result": res
        }, ensure_ascii=False, indent=2) + "</tool_call>"

    def _append_history(self, hist, tool_call, tool_res):
        hist.append(
@@ -172,17 +210,14 @@ class Base(ABC):
        if not (toolcall_session and tools):
            return
        self.is_tools = True
        self.toolcall_session = toolcall_session
        self.tools = tools

        for tool in tools:
            self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
            self.tools.append(tool)

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
    def chat_with_tools(self, system: str, history: list, gen_conf: dict={}):
        gen_conf = self._clean_conf(gen_conf)
        if system:
            history.insert(0, {"role": "system", "content": system})

        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        tk_count = 0
        hist = deepcopy(history)
@@ -190,8 +225,9 @@ class Base(ABC):
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds * 2):
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                for _ in range(self.max_rounds+1):
                    logging.info(f"{self.tools=}")
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
                    tk_count += self.total_token_count(response)
                    if any([not response.choices, not response.choices[0].message]):
                        raise Exception(f"500 response structure error. Response: {response}")
@@ -207,10 +243,11 @@ class Base(ABC):
                        return ans, tk_count

                    for tool_call in response.choices[0].message.tool_calls:
                        logging.info(f"Response {tool_call=}")
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            ans += self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
@@ -218,13 +255,20 @@ class Base(ABC):
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            ans += self._verbose_tool_use(name, {}, str(e))

                logging.warning( f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response, token_count = self._chat(history, gen_conf)
                ans += response
                tk_count += token_count
                return ans, tk_count
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    return e, tk_count

        assert False, "Shouldn't be here."

    def chat(self, system, history, gen_conf):
    def chat(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
@@ -232,7 +276,7 @@ class Base(ABC):
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries + 1):
            try:
                return self._chat(history, gen_conf)
                return self._chat(history, gen_conf, **kwargs)
            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
@@ -253,7 +297,7 @@ class Base(ABC):

        return final_tool_calls
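The streaming tool-call path gathers partial `tool_calls` deltas into complete calls before executing them (the helper whose tail, `return final_tool_calls`, appears above). The accumulation itself falls outside the shown hunks, so the sketch below illustrates the standard pattern for OpenAI-style streams using plain dictionaries rather than SDK objects; it is not the commit's exact code:

def accumulate_tool_calls(stream_chunks):
    # Merge streamed tool_call deltas into complete calls, keyed by index.
    final_tool_calls = {}
    for chunk in stream_chunks:
        for tc in chunk.get("tool_calls", []):
            idx = tc["index"]
            slot = final_tool_calls.setdefault(idx, {"id": None, "name": "", "arguments": ""})
            if tc.get("id"):
                slot["id"] = tc["id"]
            fn = tc.get("function", {})
            if fn.get("name"):
                slot["name"] = fn["name"]
            slot["arguments"] += fn.get("arguments", "")
    return final_tool_calls

chunks = [
    {"tool_calls": [{"index": 0, "id": "call_1", "function": {"name": "web_search", "arguments": '{"que'}}]},
    {"tool_calls": [{"index": 0, "function": {"arguments": 'ry": "ragflow"}'}}]},
]
print(accumulate_tool_calls(chunks))
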
    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict={}):
        gen_conf = self._clean_conf(gen_conf)
        tools = self.tools
        if system:
@@ -265,9 +309,10 @@ class Base(ABC):
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds * 2):
                for _ in range(self.max_rounds+1):
                    reasoning_start = False
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                    logging.info(f"{tools=}")
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
                    final_tool_calls = {}
                    answer = ""
                    for resp in response:
@@ -319,7 +364,8 @@ class Base(ABC):
                        name = tool_call.function.name
                        try:
                            args = json_repair.loads(tool_call.function.arguments)
                            tool_response = self.toolcall_session[name].tool_call(name, args)
                            yield self._verbose_tool_use(name, args, "Begin to call...")
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history = self._append_history(history, tool_call, tool_response)
                            yield self._verbose_tool_use(name, args, tool_response)
                        except Exception as e:
@@ -327,51 +373,45 @@ class Base(ABC):
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
                            yield self._verbose_tool_use(name, {}, str(e))

                logging.warning( f"Exceed max rounds: {self.max_rounds}")
                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
                response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
                for resp in response:
                    if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
                        raise Exception("500 response structure error.")
                    if not resp.choices[0].delta.content:
                        resp.choices[0].delta.content = ""
                        continue
                    tol = self.total_token_count(resp)
                    if not tol:
                        total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                    else:
                        total_tokens += tol
                    answer += resp.choices[0].delta.content
                    yield resp.choices[0].delta.content

                yield total_tokens
                return

            except Exception as e:
                e = self._exceptions(e, attempt)
                if e:
                    yield e
                    yield total_tokens
                    return

        yield total_tokens
        assert False, "Shouldn't be here."

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf: dict={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        total_tokens = 0
        reasoning_start = False
        try:
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                    ans = ""
                    if not reasoning_start:
                        reasoning_start = True
                        ans = "<think>"
                    ans += resp.choices[0].delta.reasoning_content + "</think>"
                else:
                    reasoning_start = False
                    ans = resp.choices[0].delta.content

                tol = self.total_token_count(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
                    total_tokens += tol

                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans

            for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
                yield delta_ans
                total_tokens += tol
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

@@ -514,7 +554,7 @@ class BaiChuanChat(Base):
            "top_p": gen_conf.get("top_p", 0.85),
        }

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
@@ -529,7 +569,7 @@ class BaiChuanChat(Base):
                ans += LENGTH_NOTIFICATION_EN
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
@@ -614,7 +654,7 @@ class ZhipuChat(Base):

        return super().chat_with_tools(system, history, gen_conf)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
@@ -626,6 +666,7 @@ class ZhipuChat(Base):
        ans = ""
        tk_count = 0
        try:
            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
@@ -675,7 +716,7 @@ class OllamaChat(Base):
                options[k] = gen_conf[k]
        return options

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        # Calculate context size
        ctx_size = self._calculate_dynamic_ctx(history)

@@ -685,7 +726,7 @@ class OllamaChat(Base):
        token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
        return ans, token_count

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
@@ -766,7 +807,7 @@ class LocalLLM(Base):
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
    def chat(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
@@ -775,7 +816,7 @@ class LocalLLM(Base):
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
@@ -894,7 +935,7 @@ class MistralChat(Base):
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
        ans = response.choices[0].message.content
        if response.choices[0].finish_reason == "length":
@@ -904,7 +945,7 @@ class MistralChat(Base):
                ans += LENGTH_NOTIFICATION_EN
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
@@ -913,7 +954,7 @@ class MistralChat(Base):
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
            response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
@@ -957,7 +998,7 @@ class BedrockChat(Base):
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        hist = []
        for item in history:
@@ -978,7 +1019,7 @@ class BedrockChat(Base):
        ans = response["output"]["message"]["content"][0]["text"]
        return ans, num_tokens_from_string(ans)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        from botocore.exceptions import ClientError

        for k in list(gen_conf.keys()):
@@ -1036,7 +1077,7 @@ class GeminiChat(Base):
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        from google.generativeai.types import content_types

        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
@@ -1059,7 +1100,7 @@ class GeminiChat(Base):
        ans = response.text
        return ans, response.usage_metadata.total_token_count

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        from google.generativeai.types import content_types

        gen_conf = self._clean_conf(gen_conf)
@@ -1101,7 +1142,7 @@ class GroqChat(Base):
                del gen_conf[k]
        return gen_conf

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
@@ -1229,7 +1270,7 @@ class CoHereChat(Base):
            response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
        )

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
@@ -1348,7 +1389,7 @@ class ReplicateChat(Base):
        self.model_name = model_name
        self.client = Client(api_token=key)

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
        response = self.client.run(
@@ -1358,7 +1399,7 @@ class ReplicateChat(Base):
        ans = "".join(response)
        return ans, num_tokens_from_string(ans)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
@@ -1402,7 +1443,7 @@ class HunyuanChat(Base):
            _gen_conf["TopP"] = gen_conf["top_p"]
        return _gen_conf

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        from tencentcloud.hunyuan.v20230901 import models

        hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
@@ -1413,7 +1454,7 @@ class HunyuanChat(Base):
        ans = response.Choices[0].Message.Content
        return ans, response.Usage.TotalTokens

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
@@ -1504,7 +1545,7 @@ class BaiduYiyanChat(Base):
        ans = response["result"]
        return ans, self.total_token_count(response)

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
@@ -1588,7 +1629,7 @@ class GoogleChat(Base):
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf):
    def _chat(self, history, gen_conf={}, **kwargs):
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        if "claude" in self.model_name:
            response = self.client.messages.create(
@@ -1626,7 +1667,7 @@ class GoogleChat(Base):
        ans = response.text
        return ans, response.usage_metadata.total_token_count

    def chat_streamly(self, system, history, gen_conf):
    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        if "claude" in self.model_name:
            if "max_tokens" in gen_conf:
                del gen_conf["max_tokens"]

rag/llm/cv_model.py (1153 lines changed): file diff suppressed because it is too large.
@@ -89,7 +89,7 @@ class DefaultRerank(Base):

            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error emptying cache: {e}")
            log_exception(e)

    def _process_batch(self, pairs, max_batch_size=None):
        """template method for subclass call"""