diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index c045d2611..820dbec53 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -26,6 +26,7 @@ from http import HTTPStatus
 from typing import Any, Protocol
 from urllib.parse import urljoin
 
+import json_repair
 import openai
 import requests
 from dashscope import Generation
@@ -67,11 +68,12 @@ class Base(ABC):
         # Configure retry parameters
         self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
         self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
+        self.max_rounds = kwargs.get("max_rounds", 5)
         self.is_tools = False
 
-    def _get_delay(self, attempt):
+    def _get_delay(self):
         """Calculate retry delay time"""
-        return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
+        return self.base_delay + random.uniform(0, 0.5)
 
     def _classify_error(self, error):
         """Classify error based on error message content"""
@@ -116,6 +118,29 @@ class Base(ABC):
                 ans += LENGTH_NOTIFICATION_EN
             return ans, self.total_token_count(response)
 
+    def _length_stop(self, ans):
+        if is_chinese([ans]):
+            return ans + LENGTH_NOTIFICATION_CN
+        return ans + LENGTH_NOTIFICATION_EN
+
+    def _exceptions(self, e, attempt):
+        logging.exception("OpenAI cat_with_tools")
+        # Classify the error
+        error_code = self._classify_error(e)
+
+        # Check if it's a rate limit error or server error and not the last attempt
+        should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
+
+        if should_retry:
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            time.sleep(delay)
+        else:
+            # For non-rate limit errors or the last attempt, return an error message
+            if attempt == self.max_retries:
+                error_code = ERROR_MAX_RETRIES
+            return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+
     def bind_tools(self, toolcall_session, tools):
         if not (toolcall_session and tools):
             return
@@ -124,76 +149,48 @@ class Base(ABC):
         self.tools = tools
 
     def chat_with_tools(self, system: str, history: list, gen_conf: dict):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-
-        tools = self.tools
-
+        gen_conf = self._clean_conf()
         if system:
             history.insert(0, {"role": "system", "content": system})
 
         ans = ""
         tk_count = 0
+        hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
-            try:
-                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-
-                assistant_output = response.choices[0].message
-                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += response.choices[0].message.content
-
-                if not response.choices[0].message.tool_calls:
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
-                    if response.choices[0].finish_reason == "length":
-                        if is_chinese([ans]):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
+                    if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+                        raise Exception("500 response structure error.")
 
-                tk_count += self.total_token_count(response)
-                history.append(assistant_output)
+                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
 
-                for tool_call in response.choices[0].message.tool_calls:
-                    name = tool_call.function.name
-                    args = json.loads(tool_call.function.arguments)
+                        ans += response.choices[0].message.content
+                        if response.choices[0].finish_reason == "length":
+                            ans = self._length_stop(ans)
 
-                    tool_response = self.toolcall_session.tool_call(name, args)
-                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        return ans, tk_count
 
-                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
-                assistant_output = final_response.choices[0].message
-                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
-                    ans += "<think>" + ans + "</think>"
-                ans += final_response.choices[0].message.content
-                if final_response.choices[0].finish_reason == "length":
-                    tk_count += self.total_token_count(response)
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-                return ans, tk_count
+                    for tool_call in response.choices[0].message.tool_calls:
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        except Exception as e:
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
-            except Exception as e:
-                logging.exception("OpenAI cat_with_tools")
-                # Classify the error
-                error_code = self._classify_error(e)
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        return e, tk_count
+        assert False, "Shouldn't be here."
 
     def chat(self, system, history, gen_conf):
         if system:
@@ -201,26 +198,14 @@ class Base(ABC):
         gen_conf = self._clean_conf(gen_conf)
 
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries):
+        for attempt in range(self.max_retries+1):
             try:
                 return self._chat(history, gen_conf)
             except Exception as e:
-                logging.exception("chat_model.Base.chat got exception")
-                # Classify the error
-                error_code = self._classify_error(e)
-
-                # Check if it's a rate limit error or server error and not the last attempt
-                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-
-                if should_retry:
-                    delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
-                    time.sleep(delay)
-                else:
-                    # For non-rate limit errors or the last attempt, return an error message
-                    if attempt == self.max_retries - 1:
-                        error_code = ERROR_MAX_RETRIES
-                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+                e = self._exceptions(e, attempt)
+                if e:
+                    return e, 0
+        assert False, "Shouldn't be here."
 
     def _wrap_toolcall_message(self, stream):
         final_tool_calls = {}
@@ -241,41 +226,48 @@ class Base(ABC):
             del gen_conf["max_tokens"]
 
         tools = self.tools
-
         if system:
             history.insert(0, {"role": "system", "content": system})
-
         ans = ""
         total_tokens = 0
-        reasoning_start = False
-        finish_completion = False
-        final_tool_calls = {}
-        try:
-            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-            while not finish_completion:
-                for resp in response:
-                    if resp.choices[0].delta.tool_calls:
-                        for tool_call in resp.choices[0].delta.tool_calls or []:
-                            index = tool_call.index
+        hist = deepcopy(history)
+        # Implement exponential backoff retry strategy
+        for attempt in range(self.max_retries+1):
+            history = hist
+            for _ in range(self.max_rounds*2):
+                reasoning_start = False
+                try:
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                    final_tool_calls = {}
+                    answer = ""
+                    for resp in response:
+                        if resp.choices[0].delta.tool_calls:
+                            for tool_call in resp.choices[0].delta.tool_calls or []:
+                                index = tool_call.index
 
-                            if index not in final_tool_calls:
-                                final_tool_calls[index] = tool_call
-                            else:
-                                final_tool_calls[index].function.arguments += tool_call.function.arguments
-                    else:
-                        if not resp.choices:
+                                if index not in final_tool_calls:
+                                    final_tool_calls[index] = tool_call
+                                else:
+                                    final_tool_calls[index].function.arguments += tool_call.function.arguments
                             continue
+
+                        if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
+                            raise Exception("500 response structure error.")
+
                         if not resp.choices[0].delta.content:
                             resp.choices[0].delta.content = ""
+
                         if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                             ans = ""
                             if not reasoning_start:
                                 reasoning_start = True
                                 ans = "<think>"
                             ans += resp.choices[0].delta.reasoning_content + "</think>"
+                            yield ans
                         else:
                             reasoning_start = False
-                            ans = resp.choices[0].delta.content
+                            answer += resp.choices[0].delta.content
+                            yield resp.choices[0].delta.content
 
                         tol = self.total_token_count(resp)
                         if not tol:
@@ -283,18 +275,18 @@ class Base(ABC):
                         else:
                             total_tokens += tol
 
-                    finish_reason = resp.choices[0].finish_reason
-                    if finish_reason == "tool_calls" and final_tool_calls:
-                        for tool_call in final_tool_calls.values():
-                            name = tool_call.function.name
-                            try:
-                                args = json.loads(tool_call.function.arguments)
-                            except Exception as e:
-                                logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
-                                yield ans + "\n**ERROR**: " + str(e)
-                                finish_completion = True
-                                break
+                        finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+                        if finish_reason == "length":
+                            yield self._length_stop("")
+                    if answer:
+                        yield total_tokens
+                        return
+
+                    for tool_call in final_tool_calls.values():
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
 
                             tool_response = self.toolcall_session.tool_call(name, args)
                             history.append(
                                 {
@@ -313,26 +305,16 @@ class Base(ABC):
                                 }
                             )
                             history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-                            final_tool_calls = {}
-                            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-                            continue
-                    if finish_reason == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, total_tokens
-                    if finish_reason == "stop":
-                        finish_completion = True
-                        yield ans
-                        break
-                    yield ans
-                    continue
+                        except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                except Exception as e:
+                    e = self._exceptions(e, attempt)
+                    if e:
+                        yield total_tokens
+                        return
-        except openai.APIError as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
+        assert False, "Shouldn't be here."
 
     def chat_streamly(self, system, history, gen_conf):
         if system:
@@ -636,49 +618,21 @@ class QWenChat(Base):
         return "".join(result_list[:-1]), result_list[-1]
 
     def _chat(self, history, gen_conf):
-        tk_count = 0
         if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            try:
-                response = super()._chat(history, gen_conf)
-                return response
-            except Exception as e:
-                error_msg = str(e).lower()
-                if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                    return self._simulate_one_shot_from_stream(history, gen_conf)
+            return super()._chat(history, gen_conf)
+        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
+        ans = ""
+        tk_count = 0
+        if response.status_code == HTTPStatus.OK:
+            ans += response.output.choices[0]["message"]["content"]
+            tk_count += self.total_token_count(response)
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                if is_chinese([ans]):
+                    ans += LENGTH_NOTIFICATION_CN
                 else:
-                    return "**ERROR**: " + str(e), tk_count
-
-        try:
-            ans = ""
-            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-            if response.status_code == HTTPStatus.OK:
-                ans += response.output.choices[0]["message"]["content"]
-                tk_count += self.total_token_count(response)
-                if response.output.choices[0].get("finish_reason", "") == "length":
-                    if is_chinese([ans]):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    return ans, tk_count
-                return "**ERROR**: " + response.message, tk_count
-        except Exception as e:
-            error_msg = str(e).lower()
-            if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
-                return self._simulate_one_shot_from_stream(history, gen_conf)
-            else:
-                return "**ERROR**: " + str(e), tk_count
-
-    def _simulate_one_shot_from_stream(self, history, gen_conf):
-        """
-        Handles models that require streaming output but need one-shot response.
-        """
-        g = self._chat_streamly("", history, gen_conf, incremental_output=True)
-        result_list = list(g)
-        error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
-        if len(error_msg_list) > 0:
-            return "**ERROR**: " + "".join(error_msg_list), 0
-        else:
-            return "".join(result_list[:-1]), result_list[-1]
+                    ans += LENGTH_NOTIFICATION_EN
+            return ans, tk_count
+        return "**ERROR**: " + response.message, tk_count
 
     def _wrap_toolcall_message(self, old_message, message):
         if not old_message:
@@ -971,10 +925,10 @@ class LocalAIChat(Base):
 
 
 class LocalLLM(Base):
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
         from jina import Client
-
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
@@ -1031,7 +985,13 @@ class VolcEngineChat(Base):
 
 
 class MiniMaxChat(Base):
-    def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+        **kwargs
+    ):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
 
         if not base_url:
@@ -1263,7 +1223,6 @@ class GeminiChat(Base):
 
     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
-
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1921,4 +1880,4 @@ class GPUStackChat(Base):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
-        super().__init__(key, model_name, base_url, **kwargs)
+        super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
diff --git a/rag/prompts.py b/rag/prompts.py
index 551ed99f9..389d6a66d 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens):
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
         cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
-        cnt += ck["content_with_weight"]
+        cnt += re.sub(r"( style=\"[^\"]+\"||)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE)
         doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
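Below, separate from the patch itself, is a minimal standalone sketch of the behavior the tool-calling changes rely on: json_repair.loads, from the json-repair package the patch imports, parses tool-call argument strings that strict json.loads rejects. The malformed payload here is invented purely for illustration.

    import json

    import json_repair

    # A tool-call arguments string the way an LLM might emit it: single quotes
    # and a trailing comma make it invalid JSON for the standard library parser.
    broken_args = "{'query': 'weather in Paris', 'days': 3,}"

    try:
        json.loads(broken_args)
    except json.JSONDecodeError as exc:
        print("json.loads rejected the payload:", exc)

    # json_repair fixes the quoting and the trailing comma before parsing,
    # so the result can be passed straight to toolcall_session.tool_call().
    args = json_repair.loads(broken_args)
    print(args)  # {'query': 'weather in Paris', 'days': 3}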