diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index c045d2611..820dbec53 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -26,6 +26,7 @@ from http import HTTPStatus
from typing import Any, Protocol
from urllib.parse import urljoin
+import json_repair
import openai
import requests
from dashscope import Generation
@@ -67,11 +68,12 @@ class Base(ABC):
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
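+ # Upper bound on tool-calling rounds per attempt, so a tool loop cannot run forever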
+ self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
- def _get_delay(self, attempt):
+ def _get_delay(self):
"""Calculate retry delay time"""
- return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
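+ # Fixed base delay plus a small random jitter; the wait no longer grows with the attempt number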
+ return self.base_delay + random.uniform(0, 0.5)
def _classify_error(self, error):
"""Classify error based on error message content"""
@@ -116,6 +118,29 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
+ def _length_stop(self, ans):
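+ """Append a language-appropriate length-limit notice to a truncated answer."""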
+ if is_chinese([ans]):
+ return ans + LENGTH_NOTIFICATION_CN
+ return ans + LENGTH_NOTIFICATION_EN
+
+ def _exceptions(self, e, attempt):
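+ """Log and classify an error; sleep and return None when a retry is warranted, otherwise return an error string."""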
+ logging.exception("chat_model.Base got exception")
+ # Classify the error
+ error_code = self._classify_error(e)
+
+ # Check if it's a rate limit error or server error and not the last attempt
+ should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
+
+ if should_retry:
+ delay = self._get_delay()
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ time.sleep(delay)
+ else:
+ # For non-rate limit errors or the last attempt, return an error message
+ if attempt == self.max_retries:
+ error_code = ERROR_MAX_RETRIES
+ return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+
def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
@@ -124,76 +149,48 @@ class Base(ABC):
self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
-
- tools = self.tools
-
+ gen_conf = self._clean_conf(gen_conf)
if system:
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
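+ # Keep an untouched copy of the conversation so each retry attempt starts from the original history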
+ hist = deepcopy(history)
# Implement exponential backoff retry strategy
- for attempt in range(self.max_retries):
- try:
- response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-
- assistant_output = response.choices[0].message
- if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
- ans += "" + ans + ""
- ans += response.choices[0].message.content
-
- if not response.choices[0].message.tool_calls:
+ for attempt in range(self.max_retries+1):
+ history = hist
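+ # Bound the number of request/tool-call iterations within a single attempt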
+ for _ in range(self.max_rounds*2):
+ try:
+ response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
- if response.choices[0].finish_reason == "length":
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
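+ # Reject structurally broken responses before reading the message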
+ if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
+ raise Exception("500 response structure error.")
- tk_count += self.total_token_count(response)
- history.append(assistant_output)
+ if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+ if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+ ans += "" + response.choices[0].message.reasoning_content + ""
- for tool_call in response.choices[0].message.tool_calls:
- name = tool_call.function.name
- args = json.loads(tool_call.function.arguments)
+ ans += response.choices[0].message.content
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
- tool_response = self.toolcall_session.tool_call(name, args)
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+ return ans, tk_count
- final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
- assistant_output = final_response.choices[0].message
- if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
- ans += "" + ans + ""
- ans += final_response.choices[0].message.content
- if final_response.choices[0].finish_reason == "length":
- tk_count += self.total_token_count(response)
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
- return ans, tk_count
+ for tool_call in response.choices[0].message.tool_calls:
+ name = tool_call.function.name
+ try:
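+ # json_repair tolerates slightly malformed JSON in the model's tool-call arguments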
+ args = json_repair.loads(tool_call.function.arguments)
+ tool_response = self.toolcall_session.tool_call(name, args)
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+ except Exception as e:
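+ # Feed the tool failure back to the model so it can adjust the call on the next round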
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
- except Exception as e:
- logging.exception("OpenAI cat_with_tools")
- # Classify the error
- error_code = self._classify_error(e)
- # Check if it's a rate limit error or server error and not the last attempt
- should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
- if should_retry:
- delay = self._get_delay(attempt)
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
- else:
- # For non-rate limit errors or the last attempt, return an error message
- if attempt == self.max_retries - 1:
- error_code = ERROR_MAX_RETRIES
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+ except Exception as e:
+ e = self._exceptions(e, attempt)
+ if e:
+ return e, tk_count
+ assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf):
if system:
@@ -201,26 +198,14 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
- for attempt in range(self.max_retries):
+ for attempt in range(self.max_retries+1):
try:
return self._chat(history, gen_conf)
except Exception as e:
- logging.exception("chat_model.Base.chat got exception")
- # Classify the error
- error_code = self._classify_error(e)
-
- # Check if it's a rate limit error or server error and not the last attempt
- should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
- if should_retry:
- delay = self._get_delay(attempt)
- logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
- time.sleep(delay)
- else:
- # For non-rate limit errors or the last attempt, return an error message
- if attempt == self.max_retries - 1:
- error_code = ERROR_MAX_RETRIES
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+ e = self._exceptions(e, attempt)
+ if e:
+ return e, 0
+ assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
@@ -241,41 +226,48 @@ class Base(ABC):
del gen_conf["max_tokens"]
tools = self.tools
-
if system:
history.insert(0, {"role": "system", "content": system})
- ans = ""
total_tokens = 0
- reasoning_start = False
- finish_completion = False
- final_tool_calls = {}
- try:
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
- while not finish_completion:
- for resp in response:
- if resp.choices[0].delta.tool_calls:
- for tool_call in resp.choices[0].delta.tool_calls or []:
- index = tool_call.index
+ hist = deepcopy(history)
+ # Implement exponential backoff retry strategy
+ for attempt in range(self.max_retries+1):
+ history = hist
+ for _ in range(self.max_rounds*2):
+ reasoning_start = False
+ try:
+ response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
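+ # Tool-call arguments arrive in streamed fragments; accumulate them by index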
+ final_tool_calls = {}
+ answer = ""
+ for resp in response:
+ if resp.choices[0].delta.tool_calls:
+ for tool_call in resp.choices[0].delta.tool_calls or []:
+ index = tool_call.index
- if index not in final_tool_calls:
- final_tool_calls[index] = tool_call
- else:
- final_tool_calls[index].function.arguments += tool_call.function.arguments
- else:
- if not resp.choices:
+ if index not in final_tool_calls:
+ final_tool_calls[index] = tool_call
+ else:
+ final_tool_calls[index].function.arguments += tool_call.function.arguments
continue
+
+ if not resp.choices or not resp.choices[0].delta or not hasattr(resp.choices[0].delta, "content"):
+ raise Exception("500 response structure error.")
+
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
+
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = ""
ans += resp.choices[0].delta.reasoning_content + ""
+ yield ans
else:
reasoning_start = False
- ans = resp.choices[0].delta.content
+ answer += resp.choices[0].delta.content
+ yield resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
@@ -283,18 +275,18 @@ class Base(ABC):
else:
total_tokens += tol
- finish_reason = resp.choices[0].finish_reason
- if finish_reason == "tool_calls" and final_tool_calls:
- for tool_call in final_tool_calls.values():
- name = tool_call.function.name
- try:
- args = json.loads(tool_call.function.arguments)
- except Exception as e:
- logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
- yield ans + "\n**ERROR**: " + str(e)
- finish_completion = True
- break
+ finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+ if finish_reason == "length":
+ yield self._length_stop("")
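+ # Any streamed content means this round is final: emit the token count and stop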
+ if answer:
+ yield total_tokens
+ return
+
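+ # No content was produced, so execute the accumulated tool calls and run another round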
+ for tool_call in final_tool_calls.values():
+ name = tool_call.function.name
+ try:
+ args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history.append(
{
@@ -313,26 +305,16 @@ class Base(ABC):
}
)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
- final_tool_calls = {}
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
- continue
- if finish_reason == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, total_tokens
- if finish_reason == "stop":
- finish_completion = True
- yield ans
- break
- yield ans
- continue
+ except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ except Exception as e:
+ e = self._exceptions(e, attempt)
+ if e:
+ yield total_tokens
+ return
- except openai.APIError as e:
- yield ans + "\n**ERROR**: " + str(e)
-
- yield total_tokens
+ assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf):
if system:
@@ -636,49 +618,21 @@ class QWenChat(Base):
return "".join(result_list[:-1]), result_list[-1]
def _chat(self, history, gen_conf):
- tk_count = 0
if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
- try:
- response = super()._chat(history, gen_conf)
- return response
- except Exception as e:
- error_msg = str(e).lower()
- if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
- return self._simulate_one_shot_from_stream(history, gen_conf)
+ return super()._chat(history, gen_conf)
+ response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
+ ans = ""
+ tk_count = 0
+ if response.status_code == HTTPStatus.OK:
+ ans += response.output.choices[0]["message"]["content"]
+ tk_count += self.total_token_count(response)
+ if response.output.choices[0].get("finish_reason", "") == "length":
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
else:
- return "**ERROR**: " + str(e), tk_count
-
- try:
- ans = ""
- response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
- if response.status_code == HTTPStatus.OK:
- ans += response.output.choices[0]["message"]["content"]
- tk_count += self.total_token_count(response)
- if response.output.choices[0].get("finish_reason", "") == "length":
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
- return "**ERROR**: " + response.message, tk_count
- except Exception as e:
- error_msg = str(e).lower()
- if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
- return self._simulate_one_shot_from_stream(history, gen_conf)
- else:
- return "**ERROR**: " + str(e), tk_count
-
- def _simulate_one_shot_from_stream(self, history, gen_conf):
- """
- Handles models that require streaming output but need one-shot response.
- """
- g = self._chat_streamly("", history, gen_conf, incremental_output=True)
- result_list = list(g)
- error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
- if len(error_msg_list) > 0:
- return "**ERROR**: " + "".join(error_msg_list), 0
- else:
- return "".join(result_list[:-1]), result_list[-1]
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, tk_count
+ return "**ERROR**: " + response.message, tk_count
def _wrap_toolcall_message(self, old_message, message):
if not old_message:
@@ -971,10 +925,10 @@ class LocalAIChat(Base):
class LocalLLM(Base):
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
from jina import Client
-
self.client = Client(port=12345, protocol="grpc", asyncio=True)
def _prepare_prompt(self, system, history, gen_conf):
@@ -1031,7 +985,13 @@ class VolcEngineChat(Base):
class MiniMaxChat(Base):
- def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
+ def __init__(
+ self,
+ key,
+ model_name,
+ base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+ **kwargs
+ ):
super().__init__(key, model_name, base_url=base_url, **kwargs)
if not base_url:
@@ -1263,7 +1223,6 @@ class GeminiChat(Base):
def _chat(self, history, gen_conf):
from google.generativeai.types import content_types
-
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = []
for item in history:
@@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
- super().__init__(key, model_name, base_url, **kwargs)
+ super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
diff --git a/rag/prompts.py b/rag/prompts.py
index 551ed99f9..389d6a66d 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
- cnt += ck["content_with_weight"]
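+ # Strip inline style attributes and html/body/head/title/doctype scaffolding from the chunk text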
+ cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})