diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 638f021f4..9ae96a702 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -18,11 +18,9 @@ import json
 import logging
 import os
 import random
-import re
 import time
 from abc import ABC
 from copy import deepcopy
-from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
 
@@ -61,9 +59,6 @@ class ToolCallSession(Protocol):
 
 
 class Base(ABC):
-    tools: list[Any]
-    toolcall_sessions: dict[str, ToolCallSession]
-
     def __init__(self, key, model_name, base_url, **kwargs):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@@ -146,6 +141,37 @@ class Base(ABC):
            error_code = ERROR_MAX_RETRIES
        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
+    def _verbose_tool_use(self, name, args, res):
+        return "<tool_call>" + json.dumps({
+            "name": name,
+            "args": args,
+            "result": res
+        }, ensure_ascii=False, indent=2) + "</tool_call>"
+
+    def _append_history(self, hist, tool_call, tool_res):
+        hist.append(
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "index": tool_call.index,
+                        "id": tool_call.id,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments,
+                        },
+                        "type": "function",
+                    },
+                ],
+            }
+        )
+        try:
+            if isinstance(tool_res, dict):
+                tool_res = json.dumps(tool_res, ensure_ascii=False)
+        finally:
+            hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
+        return hist
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
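A note on the new `_append_history` helper above: it keeps the OpenAI-style tool-calling contract, i.e. the assistant turn that requested the tool is echoed back with its `tool_calls` entry, and the tool output follows as a `role: "tool"` message carrying the same `tool_call_id`. A minimal sketch of that pairing, using a plain dict in place of the SDK's tool-call object (the sample tool and data are made up):

```python
import json

def append_history(hist, tool_call, tool_res):
    # Assistant turn that requested the tool, echoed back verbatim.
    hist.append({
        "role": "assistant",
        "tool_calls": [{
            "index": tool_call["index"],
            "id": tool_call["id"],
            "function": {
                "name": tool_call["function"]["name"],
                "arguments": tool_call["function"]["arguments"],
            },
            "type": "function",
        }],
    })
    # Tool result, keyed by the same tool_call_id so the next completion
    # request can associate the output with the call that produced it.
    if isinstance(tool_res, dict):
        tool_res = json.dumps(tool_res, ensure_ascii=False)
    hist.append({"role": "tool", "tool_call_id": tool_call["id"], "content": str(tool_res)})
    return hist

history = [{"role": "user", "content": "What is 2 + 2?"}]
call = {"index": 0, "id": "call_0", "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'}}
append_history(history, call, {"sum": 4})
print(json.dumps(history, ensure_ascii=False, indent=2))
```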
@@ -160,18 +186,19 @@ class Base(ABC):
 
         if system:
             history.insert(0, {"role": "system", "content": system})
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries + 1):
+        for attempt in range(self.max_retries+1):
             history = hist
-            for _ in range(self.max_rounds * 2):
-                try:
+            try:
+                for _ in range(self.max_rounds*2):
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
-                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
-                        raise Exception("500 response structure error.")
+                    if any([not response.choices, not response.choices[0].message]):
+                        raise Exception(f"500 response structure error. Response: {response}")
 
                     if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
                         if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
@@ -188,14 +215,17 @@ class Base(ABC):
                     try:
                         args = json_repair.loads(tool_call.function.arguments)
                         tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                        history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        history = self._append_history(history, tool_call, tool_response)
+                        ans += self._verbose_tool_use(name, args, tool_response)
                     except Exception as e:
+                        logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                         history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                        ans += self._verbose_tool_use(name, {}, str(e))
-                except Exception as e:
-                    e = self._exceptions(e, attempt)
-                    if e:
-                        return e, tk_count
+            except Exception as e:
+                e = self._exceptions(e, attempt)
+                if e:
+                    return e, tk_count
 
         assert False, "Shouldn't be here."
 
     def chat(self, system, history, gen_conf):
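The restructuring above moves the `try` out of the per-round body so that it wraps the whole tool-calling loop: a failure in any round now falls through to the per-attempt handler (`self._exceptions`), which decides between backing off and returning an error string. A compressed sketch of that control flow; `MAX_RETRIES`, `MAX_ROUNDS`, `_call_llm` and the fixed backoff are illustrative stand-ins, not the real RAGFlow names:

```python
import random
import time

MAX_RETRIES = 3
MAX_ROUNDS = 5

def _call_llm(history):
    # Stand-in for client.chat.completions.create(); always fails to show the retry path.
    raise TimeoutError("simulated transient failure")

def chat_with_retries(history):
    for attempt in range(MAX_RETRIES + 1):
        try:
            for _ in range(MAX_ROUNDS * 2):
                return _call_llm(history)
        except Exception as exc:
            if attempt == MAX_RETRIES:
                return f"**ERROR**: {exc}"
            # Exponential backoff with jitter before the next attempt (kept short here).
            time.sleep(0.05 * (2 ** attempt) + random.random() * 0.05)
    raise AssertionError("Shouldn't be here.")

print(chat_with_retries([{"role": "user", "content": "hi"}]))
```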
@@ -228,9 +258,7 @@ class Base(ABC):
         return final_tool_calls
 
     def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
+        gen_conf = self._clean_conf(gen_conf)
         tools = self.tools
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -240,9 +268,9 @@ class Base(ABC):
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries + 1):
             history = hist
-            for _ in range(self.max_rounds * 2):
-                reasoning_start = False
-                try:
+            try:
+                for _ in range(self.max_rounds*2):
+                    reasoning_start = False
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
                     answer = ""
@@ -252,9 +280,11 @@ class Base(ABC):
                                 index = tool_call.index
                                 if index not in final_tool_calls:
+                                    if not tool_call.function.arguments:
+                                        tool_call.function.arguments = ""
                                     final_tool_calls[index] = tool_call
                                 else:
-                                    final_tool_calls[index].function.arguments += tool_call.function.arguments
+                                    final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
                             continue
 
                         if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
@@ -293,40 +323,26 @@ class Base(ABC):
                         name = tool_call.function.name
                         try:
                             args = json_repair.loads(tool_call.function.arguments)
-                            tool_response = self.toolcall_sessions[name].tool_call(name, args)
-                            history.append(
-                                {
-                                    "role": "assistant",
-                                    "tool_calls": [
-                                        {
-                                            "index": tool_call.index,
-                                            "id": tool_call.id,
-                                            "function": {
-                                                "name": tool_call.function.name,
-                                                "arguments": tool_call.function.arguments,
-                                            },
-                                            "type": "function",
-                                        },
-                                    ],
-                                }
-                            )
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                            tool_response = self.toolcall_session[name].tool_call(name, args)
+                            history = self._append_history(history, tool_call, tool_response)
+                            yield self._verbose_tool_use(name, args, tool_response)
                         except Exception as e:
                             logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
-                except Exception as e:
-                    e = self._exceptions(e, attempt)
-                    if e:
-                        yield total_tokens
-                        return
+                            yield self._verbose_tool_use(name, {}, str(e))
 
-        assert False, "Shouldn't be here."
+            except Exception as e:
+                e = self._exceptions(e, attempt)
+                if e:
+                    yield total_tokens
+                    return
+
+        yield total_tokens
 
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
+        gen_conf = self._clean_conf(gen_conf)
         ans = ""
         total_tokens = 0
         reasoning_start = False
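On the streaming hunks above: tool calls arrive as incremental deltas — the first chunk for an index carries the id and function name, later chunks carry argument fragments, and the added guard normalises `None` fragments to `""` before concatenation. A small self-contained sketch of that merge, with `SimpleNamespace` standing in for the SDK delta objects and made-up chunk data:

```python
from types import SimpleNamespace

def make_delta(index, name=None, arguments=None, call_id=None):
    return SimpleNamespace(index=index, id=call_id,
                           function=SimpleNamespace(name=name, arguments=arguments))

def merge(deltas):
    final_tool_calls = {}
    for tool_call in deltas:
        index = tool_call.index
        if index not in final_tool_calls:
            if not tool_call.function.arguments:
                tool_call.function.arguments = ""  # guard against a None first fragment
            final_tool_calls[index] = tool_call
        else:
            final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
    return final_tool_calls

chunks = [
    make_delta(0, name="search", call_id="call_0", arguments=None),
    make_delta(0, arguments='{"query": '),
    make_delta(0, arguments='"raptor"}'),
]
merged = merge(chunks)
print(merged[0].function.name, merged[0].function.arguments)  # search {"query": "raptor"}
```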
@@ -542,252 +558,8 @@ class BaiChuanChat(Base):
 
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-
-        import dashscope
-
-        dashscope.api_key = key
-        self.model_name = model_name
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
-
-    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        # if self.is_reasoning_model(self.model_name):
-        #     return super().chat(system, history, gen_conf)
-
-        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
-        if not stream_flag:
-            from http import HTTPStatus
-
-            tools = self.tools
-
-            if system:
-                history.insert(0, {"role": "system", "content": system})
-
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
-            ans = ""
-            tk_count = 0
-            if response.status_code == HTTPStatus.OK:
-                assistant_output = response.output.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.output.choices[0].message.content
-
-                if "tool_calls" not in assistant_output:
-                    tk_count += self.total_token_count(response)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
-
-                while "tool_calls" in assistant_output:
-                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
-                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
-                    if tool_name:
-                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
-                        tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
-                    history.append(tool_info)
-
-                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
-                    if response.output.choices[0].get("finish_reason", "") == "length":
-                        tk_count += self.total_token_count(response)
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, tk_count
-
-                    tk_count += self.total_token_count(response)
-                    assistant_output = response.output.choices[0].message
-                    if assistant_output.content is None:
-                        assistant_output.content = ""
-                    history.append(response)
-                    ans += assistant_output["content"]
-                return ans, tk_count
-            else:
-                return "**ERROR**: " + response.message, tk_count
-        else:
-            result_list = []
-            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
-                result_list.append(result)
-            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
-            if len(error_msg_list) > 0:
-                return "**ERROR**: " + "".join(error_msg_list), 0
-            else:
-                return "".join(result_list[:-1]), result_list[-1]
-
-    def _chat(self, history, gen_conf):
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super()._chat(history, gen_conf)
-        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]["message"]["content"]
-            tk_count += self.total_token_count(response)
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                if is_chinese([ans]):
-                    ans += LENGTH_NOTIFICATION_CN
-                else:
-                    ans += LENGTH_NOTIFICATION_EN
-            return ans, tk_count
-        return "**ERROR**: " + response.message, tk_count
-
-    def _wrap_toolcall_message(self, old_message, message):
-        if not old_message:
-            return message
-        tool_call_id = message["tool_calls"][0].get("id")
-        if tool_call_id:
-            old_message.tool_calls[0]["id"] = tool_call_id
-        function = message.tool_calls[0]["function"]
-        if function:
-            if function.get("name"):
-                old_message.tool_calls[0]["function"]["name"] = function["name"]
-            if function.get("arguments"):
-                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
-        return old_message
-
-    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-            tool_info = {"content": "", "role": "tool"}
-            toolcall_message = None
-            tool_name = ""
-            tool_arguments = ""
-            finish_completion = False
-            reasoning_start = False
-            while not finish_completion:
-                for resp in response:
-                    if resp.status_code == HTTPStatus.OK:
-                        assistant_output = resp.output.choices[0].message
-                        ans = resp.output.choices[0].message.content
-                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                            ans = resp.output.choices[0].message.reasoning_content
-                            if not reasoning_start:
-                                reasoning_start = True
-                                ans = "<think>" + ans
-                            else:
-                                ans = ans + "</think>"
-
-                        if "tool_calls" not in assistant_output:
-                            reasoning_start = False
-                            tk_count += self.total_token_count(resp)
-                            if resp.output.choices[0].get("finish_reason", "") == "length":
-                                if is_chinese([ans]):
-                                    ans += LENGTH_NOTIFICATION_CN
-                                else:
-                                    ans += LENGTH_NOTIFICATION_EN
-                            finish_reason = resp.output.choices[0]["finish_reason"]
-                            if finish_reason == "stop":
-                                finish_completion = True
-                                yield ans
-                                break
-                            yield ans
-                            continue
-
-                        tk_count += self.total_token_count(resp)
-                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
-                        if "tool_calls" in assistant_output:
-                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
-                            if tool_call_finish_reason == "tool_calls":
-                                try:
-                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
-                                except Exception as e:
-                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
-                                    yield ans + "\n**ERROR**: " + str(e)
-                                    finish_completion = True
-                                    break
-
-                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
-                                history.append(toolcall_message)
-                                tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
-                                history.append(tool_info)
-                                tool_info = {"content": "", "role": "tool"}
-                                tool_name = ""
-                                tool_arguments = ""
-                                toolcall_message = None
-                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
-                    else:
-                        yield (
-                            ans + "\n**ERROR**: " + resp.output.choices[0].message
-                            if not re.search(r" (key|quota)", str(resp.message).lower())
-                            else "Out of credit. Please set the API key in **settings > Model providers.**"
-                        )
-        except Exception as e:
-            logging.exception(msg="_chat_streamly_with_tool")
-            yield ans + "\n**ERROR**: " + str(e)
-        yield tk_count
-
-    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
-        from http import HTTPStatus
-
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
-            for resp in response:
-                if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]["message"]["content"]
-                    tk_count = self.total_token_count(resp)
-                    if resp.output.choices[0].get("finish_reason", "") == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    yield ans
-                else:
-                    yield (
-                        ans + "\n**ERROR**: " + resp.message
-                        if not re.search(r" (key|quota)", str(resp.message).lower())
-                        else "Out of credit. Please set the API key in **settings > Model providers.**"
-                    )
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
-            yield txt
-
-    def chat_streamly(self, system, history, gen_conf):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super().chat_streamly(system, history, gen_conf)
-
-        return self._chat_streamly(system, history, gen_conf)
-
-    @staticmethod
-    def is_reasoning_model(model_name: str) -> bool:
-        return any(
-            [
-                model_name.lower().find("deepseek") >= 0,
-                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
-            ]
-        )
+        super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+        return
 
 
 class ZhipuChat(Base):
@@ -1877,4 +1649,4 @@ class GPUStackChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
-        super().__init__(key, model_name, base_url, **kwargs)
+        super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
diff --git a/rag/raptor.py b/rag/raptor.py
index db2c82d2e..a8d912f32 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -105,14 +105,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     ],
                     {"temperature": 0.3, "max_tokens": self._max_token},
                 )
-                cnt = re.sub(
-                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
-                    "",
-                    cnt,
-                )
-                logging.debug(f"SUM: {cnt}")
-                embds = await self._embedding_encode(cnt)
-                chunks.append((cnt, embds))
+            cnt = re.sub(
+                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                "",
+                cnt,
+            )
+            logging.debug(f"SUM: {cnt}")
+            embds = await self._embedding_encode(cnt)
+            chunks.append((cnt, embds))
 
         labels = []
         while end - start > 1:
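The `rag/raptor.py` hunk only re-indents this block; the cleanup itself — strip the bilingual truncation notice from the cluster summary, log it, embed it, and append the chunk — is unchanged. A self-contained look at that step (the summary strings are made up and the embedder is a stand-in for `self._embedding_encode`):

```python
import re

TRUNCATION_NOTICE = "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)"

def embed(text: str) -> list[float]:
    # Stand-in for the real embedding call; returns a dummy vector.
    return [float(len(text))]

chunks: list[tuple[str, list[float]]] = []
for cnt in [
    "A complete cluster summary.",
    "A truncated cluster summary. For the content length reason, it stopped, continue?",
]:
    cnt = re.sub(TRUNCATION_NOTICE, "", cnt)  # drop the truncation notice if present
    chunks.append((cnt, embed(cnt)))

print([c[0] for c in chunks])
```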