diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 638f021f4..9ae96a702 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -18,11 +18,9 @@ import json
import logging
import os
import random
-import re
import time
from abc import ABC
from copy import deepcopy
-from http import HTTPStatus
from typing import Any, Protocol
from urllib.parse import urljoin
@@ -61,9 +59,6 @@ class ToolCallSession(Protocol):
class Base(ABC):
- tools: list[Any]
- toolcall_sessions: dict[str, ToolCallSession]
-
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@@ -146,6 +141,37 @@ class Base(ABC):
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ def _verbose_tool_use(self, name, args, res):
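+        # Render a tool invocation (name, arguments, result) as pretty-printed JSON wrapped in <tool_call> tags so it can be surfaced in the answer.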
+        return "<tool_call>" + json.dumps({
+ "name": name,
+ "args": args,
+ "result": res
+        }, ensure_ascii=False, indent=2) + "</tool_call>"
+
+ def _append_history(self, hist, tool_call, tool_res):
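+        # Record the assistant message that issued the tool call, followed by the matching tool-result message.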
+ hist.append(
+ {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "index": tool_call.index,
+ "id": tool_call.id,
+ "function": {
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ "type": "function",
+ },
+ ],
+ }
+ )
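+        # Serialize dict results to JSON; even if serialization fails, the tool message is still appended.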
+ try:
+ if isinstance(tool_res, dict):
+ tool_res = json.dumps(tool_res, ensure_ascii=False)
+ finally:
+ hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
+ return hist
+
def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
@@ -160,18 +186,19 @@ class Base(ABC):
if system:
history.insert(0, {"role": "system", "content": system})
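+        # Normalize generation options (e.g. dropping max_tokens where the provider rejects it) before calling the API.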
+ gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
- for attempt in range(self.max_retries + 1):
+ for attempt in range(self.max_retries+1):
history = hist
- for _ in range(self.max_rounds * 2):
- try:
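+            # A single try wraps the whole multi-round tool loop, so any failure falls through to the exponential-backoff handling below.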
+ try:
+ for _ in range(self.max_rounds*2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
- if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
- raise Exception("500 response structure error.")
+ if any([not response.choices, not response.choices[0].message]):
+ raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@@ -188,14 +215,17 @@ class Base(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
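+                            # Keep both the tool call and its result in history, and echo the invocation into the answer text.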
+ history = self._append_history(history, tool_call, tool_response)
+ ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ ans += self._verbose_tool_use(name, {}, str(e))
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- return e, tk_count
+ except Exception as e:
+ e = self._exceptions(e, attempt)
+ if e:
+ return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf):
@@ -228,9 +258,7 @@ class Base(ABC):
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
-
+ gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system:
history.insert(0, {"role": "system", "content": system})
@@ -240,9 +268,9 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
- for _ in range(self.max_rounds * 2):
- reasoning_start = False
- try:
+ try:
+ for _ in range(self.max_rounds*2):
+ reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
final_tool_calls = {}
answer = ""
@@ -252,9 +280,11 @@ class Base(ABC):
index = tool_call.index
if index not in final_tool_calls:
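+                                # Some providers stream the first tool-call delta with arguments=None; normalize to "" so later chunks can be concatenated.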
+ if not tool_call.function.arguments:
+ tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
- final_tool_calls[index].function.arguments += tool_call.function.arguments
+ final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
@@ -293,40 +323,26 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
- tool_response = self.toolcall_sessions[name].tool_call(name, args)
- history.append(
- {
- "role": "assistant",
- "tool_calls": [
- {
- "index": tool_call.index,
- "id": tool_call.id,
- "function": {
- "name": tool_call.function.name,
- "arguments": tool_call.function.arguments,
- },
- "type": "function",
- },
- ],
- }
- )
- history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
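+                            # Record the tool call and its result in history, then surface the invocation to the stream.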
+                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
+ history = self._append_history(history, tool_call, tool_response)
+ yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
- except Exception as e:
- e = self._exceptions(e, attempt)
- if e:
- yield total_tokens
- return
+ yield self._verbose_tool_use(name, {}, str(e))
- assert False, "Shouldn't be here."
+ except Exception as e:
+ e = self._exceptions(e, attempt)
+ if e:
+ yield total_tokens
+ return
+
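+            # Close the stream by reporting the accumulated token count.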
+ yield total_tokens
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
+ gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
reasoning_start = False
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
- super().__init__(key, model_name, base_url=base_url, **kwargs)
-
- import dashscope
-
- dashscope.api_key = key
- self.model_name = model_name
- if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
- super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
-
- def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
- # if self.is_reasoning_model(self.model_name):
- # return super().chat(system, history, gen_conf)
-
- stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
- if not stream_flag:
- from http import HTTPStatus
-
- tools = self.tools
-
- if system:
- history.insert(0, {"role": "system", "content": system})
-
- response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
- ans = ""
- tk_count = 0
- if response.status_code == HTTPStatus.OK:
- assistant_output = response.output.choices[0].message
- if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
- ans += response.output.choices[0].message.content
-
- if "tool_calls" not in assistant_output:
- tk_count += self.total_token_count(response)
- if response.output.choices[0].get("finish_reason", "") == "length":
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
-
- tk_count += self.total_token_count(response)
- history.append(assistant_output)
-
- while "tool_calls" in assistant_output:
- tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
- tool_name = assistant_output.tool_calls[0]["function"]["name"]
- if tool_name:
- arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
- tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
- history.append(tool_info)
-
- response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
- if response.output.choices[0].get("finish_reason", "") == "length":
- tk_count += self.total_token_count(response)
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
-
- tk_count += self.total_token_count(response)
- assistant_output = response.output.choices[0].message
- if assistant_output.content is None:
- assistant_output.content = ""
- history.append(response)
- ans += assistant_output["content"]
- return ans, tk_count
- else:
- return "**ERROR**: " + response.message, tk_count
- else:
- result_list = []
- for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
- result_list.append(result)
- error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
- if len(error_msg_list) > 0:
- return "**ERROR**: " + "".join(error_msg_list), 0
- else:
- return "".join(result_list[:-1]), result_list[-1]
-
- def _chat(self, history, gen_conf):
- if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
- return super()._chat(history, gen_conf)
- response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
- ans = ""
- tk_count = 0
- if response.status_code == HTTPStatus.OK:
- ans += response.output.choices[0]["message"]["content"]
- tk_count += self.total_token_count(response)
- if response.output.choices[0].get("finish_reason", "") == "length":
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- return ans, tk_count
- return "**ERROR**: " + response.message, tk_count
-
- def _wrap_toolcall_message(self, old_message, message):
- if not old_message:
- return message
- tool_call_id = message["tool_calls"][0].get("id")
- if tool_call_id:
- old_message.tool_calls[0]["id"] = tool_call_id
- function = message.tool_calls[0]["function"]
- if function:
- if function.get("name"):
- old_message.tool_calls[0]["function"]["name"] = function["name"]
- if function.get("arguments"):
- old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
- return old_message
-
- def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
- from http import HTTPStatus
-
- if system:
- history.insert(0, {"role": "system", "content": system})
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
- ans = ""
- tk_count = 0
- try:
- response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
- tool_info = {"content": "", "role": "tool"}
- toolcall_message = None
- tool_name = ""
- tool_arguments = ""
- finish_completion = False
- reasoning_start = False
- while not finish_completion:
- for resp in response:
- if resp.status_code == HTTPStatus.OK:
- assistant_output = resp.output.choices[0].message
- ans = resp.output.choices[0].message.content
- if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
- ans = resp.output.choices[0].message.reasoning_content
- if not reasoning_start:
- reasoning_start = True
-                                ans = "<think>" + ans
- else:
-                                ans = ans + "</think>"
-
- if "tool_calls" not in assistant_output:
- reasoning_start = False
- tk_count += self.total_token_count(resp)
- if resp.output.choices[0].get("finish_reason", "") == "length":
- if is_chinese([ans]):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- finish_reason = resp.output.choices[0]["finish_reason"]
- if finish_reason == "stop":
- finish_completion = True
- yield ans
- break
- yield ans
- continue
-
- tk_count += self.total_token_count(resp)
- toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
- if "tool_calls" in assistant_output:
- tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
- if tool_call_finish_reason == "tool_calls":
- try:
- tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
- except Exception as e:
- logging.exception(msg="_chat_streamly_with_tool tool call error")
- yield ans + "\n**ERROR**: " + str(e)
- finish_completion = True
- break
-
- tool_name = toolcall_message.tool_calls[0]["function"]["name"]
- history.append(toolcall_message)
- tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
- history.append(tool_info)
- tool_info = {"content": "", "role": "tool"}
- tool_name = ""
- tool_arguments = ""
- toolcall_message = None
- response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
- else:
- yield (
- ans + "\n**ERROR**: " + resp.output.choices[0].message
- if not re.search(r" (key|quota)", str(resp.message).lower())
- else "Out of credit. Please set the API key in **settings > Model providers.**"
- )
- except Exception as e:
- logging.exception(msg="_chat_streamly_with_tool")
- yield ans + "\n**ERROR**: " + str(e)
- yield tk_count
-
- def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
- from http import HTTPStatus
-
- if system:
- history.insert(0, {"role": "system", "content": system})
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
- ans = ""
- tk_count = 0
- try:
- response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
- for resp in response:
- if resp.status_code == HTTPStatus.OK:
- ans = resp.output.choices[0]["message"]["content"]
- tk_count = self.total_token_count(resp)
- if resp.output.choices[0].get("finish_reason", "") == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- yield ans
- else:
- yield (
- ans + "\n**ERROR**: " + resp.message
- if not re.search(r" (key|quota)", str(resp.message).lower())
- else "Out of credit. Please set the API key in **settings > Model providers.**"
- )
- except Exception as e:
- yield ans + "\n**ERROR**: " + str(e)
-
- yield tk_count
-
- def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
-
- for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
- yield txt
-
- def chat_streamly(self, system, history, gen_conf):
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
- if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
- return super().chat_streamly(system, history, gen_conf)
-
- return self._chat_streamly(system, history, gen_conf)
-
- @staticmethod
- def is_reasoning_model(model_name: str) -> bool:
- return any(
- [
- model_name.lower().find("deepseek") >= 0,
- model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
- ]
- )
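+        # All QWen models are now served through DashScope's OpenAI-compatible endpoint, so the shared Base implementation covers chat, tool calls and streaming.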
+ super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+ return
class ZhipuChat(Base):
@@ -1877,4 +1649,4 @@ class GPUStackChat(Base):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
- super().__init__(key, model_name, base_url, **kwargs)
+ super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
diff --git a/rag/raptor.py b/rag/raptor.py
index db2c82d2e..a8d912f32 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -105,14 +105,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
],
{"temperature": 0.3, "max_tokens": self._max_token},
)
- cnt = re.sub(
- "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
- "",
- cnt,
- )
- logging.debug(f"SUM: {cnt}")
- embds = await self._embedding_encode(cnt)
- chunks.append((cnt, embds))
+ cnt = re.sub(
+ "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+ "",
+ cnt,
+ )
+ logging.debug(f"SUM: {cnt}")
+ embds = await self._embedding_encode(cnt)
+ chunks.append((cnt, embds))
labels = []
while end - start > 1: