mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Refa: chat with tools. (#8210)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
@ -26,6 +26,7 @@ from http import HTTPStatus
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import json_repair
|
||||
import openai
|
||||
import requests
|
||||
from dashscope import Generation
|
||||
@ -67,11 +68,12 @@ class Base(ABC):
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
|
||||
def _get_delay(self, attempt):
|
||||
def _get_delay(self):
|
||||
"""Calculate retry delay time"""
|
||||
return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
|
||||
return self.base_delay + random.uniform(0, 0.5)
|
||||
|
||||
def _classify_error(self, error):
|
||||
"""Classify error based on error message content"""
|
||||
@ -116,6 +118,29 @@ class Base(ABC):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
def _length_stop(self, ans):
|
||||
if is_chinese([ans]):
|
||||
return ans + LENGTH_NOTIFICATION_CN
|
||||
return ans + LENGTH_NOTIFICATION_EN
|
||||
|
||||
def _exceptions(self, e, attempt):
|
||||
logging.exception("OpenAI cat_with_tools")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
|
||||
# Check if it's a rate limit error or server error and not the last attempt
|
||||
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
|
||||
|
||||
if should_retry:
|
||||
delay = self._get_delay()
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# For non-rate limit errors or the last attempt, return an error message
|
||||
if attempt == self.max_retries:
|
||||
error_code = ERROR_MAX_RETRIES
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not (toolcall_session and tools):
|
||||
return
|
||||
@ -124,76 +149,48 @@ class Base(ABC):
|
||||
self.tools = tools
|
||||
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
|
||||
tools = self.tools
|
||||
|
||||
gen_conf = self._clean_conf()
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
|
||||
|
||||
assistant_output = response.choices[0].message
|
||||
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
|
||||
ans += "<think>" + ans + "</think>"
|
||||
ans += response.choices[0].message.content
|
||||
|
||||
if not response.choices[0].message.tool_calls:
|
||||
for attempt in range(self.max_retries+1):
|
||||
history = hist
|
||||
for _ in range(self.max_rounds*2):
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
|
||||
tk_count += self.total_token_count(response)
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, tk_count
|
||||
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
|
||||
raise Exception("500 response structure error.")
|
||||
|
||||
tk_count += self.total_token_count(response)
|
||||
history.append(assistant_output)
|
||||
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
||||
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
|
||||
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
|
||||
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
name = tool_call.function.name
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
ans += response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
|
||||
return ans, tk_count
|
||||
|
||||
final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
|
||||
assistant_output = final_response.choices[0].message
|
||||
if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
|
||||
ans += "<think>" + ans + "</think>"
|
||||
ans += final_response.choices[0].message.content
|
||||
if final_response.choices[0].finish_reason == "length":
|
||||
tk_count += self.total_token_count(response)
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, tk_count
|
||||
return ans, tk_count
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
|
||||
except Exception as e:
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI cat_with_tools")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
|
||||
# Check if it's a rate limit error or server error and not the last attempt
|
||||
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
|
||||
|
||||
if should_retry:
|
||||
delay = self._get_delay(attempt)
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# For non-rate limit errors or the last attempt, return an error message
|
||||
if attempt == self.max_retries - 1:
|
||||
error_code = ERROR_MAX_RETRIES
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, tk_count
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
@ -201,26 +198,14 @@ class Base(ABC):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries):
|
||||
for attempt in range(self.max_retries+1):
|
||||
try:
|
||||
return self._chat(history, gen_conf)
|
||||
except Exception as e:
|
||||
logging.exception("chat_model.Base.chat got exception")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
|
||||
# Check if it's a rate limit error or server error and not the last attempt
|
||||
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
|
||||
|
||||
if should_retry:
|
||||
delay = self._get_delay(attempt)
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# For non-rate limit errors or the last attempt, return an error message
|
||||
if attempt == self.max_retries - 1:
|
||||
error_code = ERROR_MAX_RETRIES
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, 0
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def _wrap_toolcall_message(self, stream):
|
||||
final_tool_calls = {}
|
||||
@ -241,41 +226,48 @@ class Base(ABC):
|
||||
del gen_conf["max_tokens"]
|
||||
|
||||
tools = self.tools
|
||||
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
reasoning_start = False
|
||||
finish_completion = False
|
||||
final_tool_calls = {}
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
|
||||
while not finish_completion:
|
||||
for resp in response:
|
||||
if resp.choices[0].delta.tool_calls:
|
||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries+1):
|
||||
history = hist
|
||||
for _ in range(self.max_rounds*2):
|
||||
reasoning_start = False
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
for resp in response:
|
||||
if resp.choices[0].delta.tool_calls:
|
||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
|
||||
if index not in final_tool_calls:
|
||||
final_tool_calls[index] = tool_call
|
||||
else:
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
if not resp.choices:
|
||||
if index not in final_tool_calls:
|
||||
final_tool_calls[index] = tool_call
|
||||
else:
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
||||
continue
|
||||
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
|
||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
yield ans
|
||||
else:
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
@ -283,18 +275,18 @@ class Base(ABC):
|
||||
else:
|
||||
total_tokens += tol
|
||||
|
||||
finish_reason = resp.choices[0].finish_reason
|
||||
if finish_reason == "tool_calls" and final_tool_calls:
|
||||
for tool_call in final_tool_calls.values():
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
finish_completion = True
|
||||
break
|
||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
||||
if finish_reason == "length":
|
||||
yield self._length_stop("")
|
||||
|
||||
if answer:
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
for tool_call in final_tool_calls.values():
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history.append(
|
||||
{
|
||||
@ -313,26 +305,16 @@ class Base(ABC):
|
||||
}
|
||||
)
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
|
||||
final_tool_calls = {}
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
|
||||
continue
|
||||
if finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_tokens
|
||||
if finish_reason == "stop":
|
||||
finish_completion = True
|
||||
yield ans
|
||||
break
|
||||
yield ans
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
except openai.APIError as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
@ -636,49 +618,21 @@ class QWenChat(Base):
|
||||
return "".join(result_list[:-1]), result_list[-1]
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
tk_count = 0
|
||||
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
|
||||
try:
|
||||
response = super()._chat(history, gen_conf)
|
||||
return response
|
||||
except Exception as e:
|
||||
error_msg = str(e).lower()
|
||||
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
|
||||
return self._simulate_one_shot_from_stream(history, gen_conf)
|
||||
return super()._chat(history, gen_conf)
|
||||
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
ans += response.output.choices[0]["message"]["content"]
|
||||
tk_count += self.total_token_count(response)
|
||||
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
return "**ERROR**: " + str(e), tk_count
|
||||
|
||||
try:
|
||||
ans = ""
|
||||
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
ans += response.output.choices[0]["message"]["content"]
|
||||
tk_count += self.total_token_count(response)
|
||||
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, tk_count
|
||||
return "**ERROR**: " + response.message, tk_count
|
||||
except Exception as e:
|
||||
error_msg = str(e).lower()
|
||||
if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
|
||||
return self._simulate_one_shot_from_stream(history, gen_conf)
|
||||
else:
|
||||
return "**ERROR**: " + str(e), tk_count
|
||||
|
||||
def _simulate_one_shot_from_stream(self, history, gen_conf):
|
||||
"""
|
||||
Handles models that require streaming output but need one-shot response.
|
||||
"""
|
||||
g = self._chat_streamly("", history, gen_conf, incremental_output=True)
|
||||
result_list = list(g)
|
||||
error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
|
||||
if len(error_msg_list) > 0:
|
||||
return "**ERROR**: " + "".join(error_msg_list), 0
|
||||
else:
|
||||
return "".join(result_list[:-1]), result_list[-1]
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, tk_count
|
||||
return "**ERROR**: " + response.message, tk_count
|
||||
|
||||
def _wrap_toolcall_message(self, old_message, message):
|
||||
if not old_message:
|
||||
@ -971,10 +925,10 @@ class LocalAIChat(Base):
|
||||
|
||||
|
||||
class LocalLLM(Base):
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
from jina import Client
|
||||
|
||||
self.client = Client(port=12345, protocol="grpc", asyncio=True)
|
||||
|
||||
def _prepare_prompt(self, system, history, gen_conf):
|
||||
@ -1031,7 +985,13 @@ class VolcEngineChat(Base):
|
||||
|
||||
|
||||
class MiniMaxChat(Base):
|
||||
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
key,
|
||||
model_name,
|
||||
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
|
||||
if not base_url:
|
||||
@ -1263,7 +1223,6 @@ class GeminiChat(Base):
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
hist = []
|
||||
for item in history:
|
||||
@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
||||
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
|
||||
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
|
||||
cnt += ck["content_with_weight"]
|
||||
cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
|
||||
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user