# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import json import logging import os import random import re import time from abc import ABC from copy import deepcopy from urllib.parse import urljoin import json_repair import litellm import openai import requests from openai import OpenAI from openai.lib.azure import AzureOpenAI from strenum import StrEnum from zhipuai import ZhipuAI from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.nlp import is_chinese, is_english from common.token_utils import num_tokens_from_string, total_token_count_from_response # Error message constants class LLMErrorCode(StrEnum): ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED" ERROR_AUTHENTICATION = "AUTH_ERROR" ERROR_INVALID_REQUEST = "INVALID_REQUEST" ERROR_SERVER = "SERVER_ERROR" ERROR_TIMEOUT = "TIMEOUT" ERROR_CONNECTION = "CONNECTION_ERROR" ERROR_MODEL = "MODEL_ERROR" ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS" ERROR_CONTENT_FILTER = "CONTENT_FILTERED" ERROR_QUOTA = "QUOTA_EXCEEDED" ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED" ERROR_GENERIC = "GENERIC_ERROR" class ReActMode(StrEnum): FUNCTION_CALL = "function_call" REACT = "react" ERROR_PREFIX = "**ERROR**" LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。" LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) self.max_rounds = kwargs.get("max_rounds", 5) self.is_tools = False self.tools = [] self.toolcall_sessions = {} def _get_delay(self): """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA), (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT), (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION), (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST), (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER), (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT), (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION), (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER), (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL), (["max rounds"], LLMErrorCode.ERROR_MODEL), ] for words, code in keywords_mapping: if re.search("({})".format("|".join(words)), error_str): return code return LLMErrorCode.ERROR_GENERIC def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] allowed_conf = { "temperature", "max_completion_tokens", "top_p", "stream", "stream_options", "stop", "n", "presence_penalty", "frequency_penalty", "functions", "function_call", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "logprobs", "top_logprobs", "extra_headers" } gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} return gen_conf def _chat(self, history, gen_conf, **kwargs): logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwq") >= 0: logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly") final_ans = "" tol_token = 0 for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): if delta.startswith("") or delta.endswith(""): continue final_ans += delta tol_token = tol if len(final_ans.strip()) == 0: final_ans = "**ERROR**: Empty response from reasoning model" return final_ans.strip(), tol_token if self.model_name.lower().find("qwen3") >= 0: kwargs["extra_body"] = {"enable_thinking": False} response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) if not response.choices or not response.choices[0].message or not response.choices[0].message.content: return "", 0 ans = response.choices[0].message.content.strip() if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) return ans, total_token_count_from_response(response) def _chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False if kwargs.get("stop") or "stop" in gen_conf: response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop")) else: response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) for resp in response: if not resp.choices: continue if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: ans = "" if not reasoning_start: reasoning_start = True ans = "" ans += resp.choices[0].delta.reasoning_content + "" else: reasoning_start = False ans = resp.choices[0].delta.content tol = total_token_count_from_response(resp) if not tol: tol = num_tokens_from_string(resp.choices[0].delta.content) if resp.choices[0].finish_reason == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN yield ans, tol def _length_stop(self, ans): if is_chinese([ans]): return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_EN @property def _retryable_errors(self) -> set[str]: return { LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER, } def _should_retry(self, error_code: str) -> bool: return error_code in self._retryable_errors def _exceptions(self, e, attempt) -> str | None: logging.exception("OpenAI chat_with_tools") # Classify the error error_code = self._classify_error(e) if attempt == self.max_retries: error_code = LLMErrorCode.ERROR_MAX_RETRIES if self._should_retry(error_code): delay = self._get_delay() logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") time.sleep(delay) return None return f"{ERROR_PREFIX}: {error_code} - {str(e)}" def _verbose_tool_use(self, name, args, res): return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" def _append_history(self, hist, tool_call, tool_res): hist.append( { "role": "assistant", "tool_calls": [ { "index": tool_call.index, "id": tool_call.id, "function": { "name": tool_call.function.name, "arguments": tool_call.function.arguments, }, "type": "function", }, ], } ) try: if isinstance(tool_res, dict): tool_res = json.dumps(tool_res, ensure_ascii=False) finally: hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)}) return hist def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return self.is_tools = True self.toolcall_session = toolcall_session self.tools = tools def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) ans = "" tk_count = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): history = hist try: for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) tk_count += total_token_count_from_response(response) if any([not response.choices, not response.choices[0].message]): raise Exception(f"500 response structure error. Response: {response}") if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: ans += "" + response.choices[0].message.reasoning_content + "" ans += response.choices[0].message.content if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) return ans, tk_count for tool_call in response.choices[0].message.tool_calls: logging.info(f"Response {tool_call=}") name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) tool_response = self.toolcall_session.tool_call(name, args) history = self._append_history(history, tool_call, tool_response) ans += self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) ans += self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) response, token_count = self._chat(history, gen_conf) ans += response tk_count += token_count return ans, tk_count except Exception as e: e = self._exceptions(e, attempt) if e: return e, tk_count assert False, "Shouldn't be here." def chat(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): try: return self._chat(history, gen_conf, **kwargs) except Exception as e: e = self._exceptions(e, attempt) if e: return e, 0 assert False, "Shouldn't be here." def _wrap_toolcall_message(self, stream): final_tool_calls = {} for chunk in stream: for tool_call in chunk.choices[0].delta.tool_calls or []: index = tool_call.index if index not in final_tool_calls: final_tool_calls[index] = tool_call final_tool_calls[index].function.arguments += tool_call.function.arguments return final_tool_calls def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) total_tokens = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): history = hist try: for _ in range(self.max_rounds + 1): reasoning_start = False logging.info(f"{tools=}") response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) final_tool_calls = {} answer = "" for resp in response: if resp.choices[0].delta.tool_calls: for tool_call in resp.choices[0].delta.tool_calls or []: index = tool_call.index if index not in final_tool_calls: if not tool_call.function.arguments: tool_call.function.arguments = "" final_tool_calls[index] = tool_call else: final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else "" continue if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): raise Exception("500 response structure error.") if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: ans = "" if not reasoning_start: reasoning_start = True ans = "" ans += resp.choices[0].delta.reasoning_content + "" yield ans else: reasoning_start = False answer += resp.choices[0].delta.content yield resp.choices[0].delta.content tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(resp.choices[0].delta.content) else: total_tokens = tol finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" if finish_reason == "length": yield self._length_stop("") if answer: yield total_tokens return for tool_call in final_tool_calls.values(): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) yield self._verbose_tool_use(name, args, "Begin to call...") tool_response = self.toolcall_session.tool_call(name, args) history = self._append_history(history, tool_call, tool_response) yield self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) yield self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) for resp in response: if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): raise Exception("500 response structure error.") if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" continue tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(resp.choices[0].delta.content) else: total_tokens = tol answer += resp.choices[0].delta.content yield resp.choices[0].delta.content yield total_tokens return except Exception as e: e = self._exceptions(e, attempt) if e: yield e yield total_tokens return assert False, "Shouldn't be here." def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 try: for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): yield delta_ans total_tokens += tol except openai.APIError as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens def _calculate_dynamic_ctx(self, history): """Calculate dynamic context window size""" def count_tokens(text): """Calculate token count for text""" # Simple calculation: 1 token per ASCII character # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) total = 0 for char in text: if ord(char) < 128: # ASCII characters total += 1 else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) total += 2 return total # Calculate total tokens for all messages total_tokens = 0 for message in history: content = message.get("content", "") # Calculate content tokens content_tokens = count_tokens(content) # Add role marker token overhead role_tokens = 4 total_tokens += content_tokens + role_tokens # Apply 1.2x buffer ratio total_tokens_with_buffer = int(total_tokens * 1.2) if total_tokens_with_buffer <= 8192: ctx_size = 8192 else: ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 ctx_size = ctx_multiplier * 8192 return ctx_size class GptTurbo(Base): _FACTORY_NAME = "OpenAI" def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs): if not base_url: base_url = "https://api.openai.com/v1" super().__init__(key, model_name, base_url, **kwargs) class XinferenceChat(Base): _FACTORY_NAME = "Xinference" def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") super().__init__(key, model_name, base_url, **kwargs) class HuggingFaceChat(Base): _FACTORY_NAME = "HuggingFace" def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class ModelScopeChat(Base): _FACTORY_NAME = "ModelScope" def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class AzureChat(Base): _FACTORY_NAME = "Azure-OpenAI" def __init__(self, key, model_name, base_url, **kwargs): api_key = json.loads(key).get("api_key", "") api_version = json.loads(key).get("api_version", "2024-02-01") super().__init__(key, model_name, base_url, **kwargs) self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) self.model_name = model_name @property def _retryable_errors(self) -> set[str]: return { LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER, LLMErrorCode.ERROR_QUOTA, } class BaiChuanChat(Base): _FACTORY_NAME = "BaiChuan" def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs): if not base_url: base_url = "https://api.baichuan-ai.com/v1" super().__init__(key, model_name, base_url, **kwargs) @staticmethod def _format_params(params): return { "temperature": params.get("temperature", 0.3), "top_p": params.get("top_p", 0.85), } def _clean_conf(self, gen_conf): return { "temperature": gen_conf.get("temperature", 0.3), "top_p": gen_conf.get("top_p", 0.85), } def _chat(self, history, gen_conf={}, **kwargs): response = self.client.chat.completions.create( model=self.model_name, messages=history, extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, **gen_conf, ) ans = response.choices[0].message.content.strip() if response.choices[0].finish_reason == "length": if is_chinese([ans]): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) def chat_streamly(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] ans = "" total_tokens = 0 try: response = self.client.chat.completions.create( model=self.model_name, messages=history, extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, stream=True, **self._format_params(gen_conf), ) for resp in response: if not resp.choices: continue if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" ans = resp.choices[0].delta.content tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(resp.choices[0].delta.content) else: total_tokens = tol if resp.choices[0].finish_reason == "length": if is_chinese([ans]): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens class ZhipuChat(Base): _FACTORY_NAME = "ZHIPU-AI" def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) self.client = ZhipuAI(api_key=key) self.model_name = model_name def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] gen_conf = self._clean_conf_plealty(gen_conf) return gen_conf def _clean_conf_plealty(self, gen_conf): if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] return gen_conf def chat_with_tools(self, system: str, history: list, gen_conf: dict): gen_conf = self._clean_conf_plealty(gen_conf) return super().chat_with_tools(system, history, gen_conf) def chat_streamly(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" tk_count = 0 try: logging.info(json.dumps(history, ensure_ascii=False, indent=2)) response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) for resp in response: if not resp.choices[0].delta.content: continue delta = resp.choices[0].delta.content ans = delta if resp.choices[0].finish_reason == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN tk_count = total_token_count_from_response(resp) if resp.choices[0].finish_reason == "stop": tk_count = total_token_count_from_response(resp) yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield tk_count def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): gen_conf = self._clean_conf_plealty(gen_conf) return super().chat_streamly_with_tools(system, history, gen_conf) class LocalAIChat(Base): _FACTORY_NAME = "LocalAI" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] class LocalLLM(Base): def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from jina import Client self.client = Client(port=12345, protocol="grpc", asyncio=True) def _prepare_prompt(self, system, history, gen_conf): from rag.svr.jina_server import Prompt if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) return Prompt(message=history, gen_conf=gen_conf) def _stream_response(self, endpoint, prompt): from rag.svr.jina_server import Generation answer = "" try: res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation) loop = asyncio.get_event_loop() try: while True: answer = loop.run_until_complete(res.__anext__()).text yield answer except StopAsyncIteration: pass except Exception as e: yield answer + "\n**ERROR**: " + str(e) yield num_tokens_from_string(answer) def chat(self, system, history, gen_conf={}, **kwargs): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) chat_gen = self._stream_response("/chat", prompt) ans = next(chat_gen) total_tokens = next(chat_gen) return ans, total_tokens def chat_streamly(self, system, history, gen_conf={}, **kwargs): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) return self._stream_response("/stream", prompt) class VolcEngineChat(Base): _FACTORY_NAME = "VolcEngine" def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs): """ Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use model_name is for display only """ base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3" ark_api_key = json.loads(key).get("ark_api_key", "") model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") super().__init__(ark_api_key, model_name, base_url, **kwargs) class MiniMaxChat(Base): _FACTORY_NAME = "MiniMax" def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2" self.base_url = base_url self.model_name = model_name self.api_key = key def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] return gen_conf def _chat(self, history, gen_conf): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf}) response = requests.request("POST", url=self.base_url, headers=headers, data=payload) response = response.json() ans = response["choices"][0]["message"]["content"].strip() if response["choices"][0]["finish_reason"] == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) def chat_streamly(self, system, history, gen_conf): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] ans = "" total_tokens = 0 try: headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = json.dumps( { "model": self.model_name, "messages": history, "stream": True, **gen_conf, } ) response = requests.request( "POST", url=self.base_url, headers=headers, data=payload, ) for resp in response.text.split("\n\n")[:-1]: resp = json.loads(resp[6:]) text = "" if "choices" in resp and "delta" in resp["choices"][0]: text = resp["choices"][0]["delta"]["content"] ans = text tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(text) else: total_tokens = tol yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens class MistralChat(Base): _FACTORY_NAME = "Mistral" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from mistralai.client import MistralClient self.client = MistralClient(api_key=key) self.model_name = model_name def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] return gen_conf def _chat(self, history, gen_conf={}, **kwargs): gen_conf = self._clean_conf(gen_conf) response = self.client.chat(model=self.model_name, messages=history, **gen_conf) ans = response.choices[0].message.content if response.choices[0].finish_reason == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) def chat_streamly(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 try: response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs) for resp in response: if not resp.choices or not resp.choices[0].delta.content: continue ans = resp.choices[0].delta.content total_tokens += 1 if resp.choices[0].finish_reason == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN yield ans except openai.APIError as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens class LmStudioChat(Base): _FACTORY_NAME = "LM-Studio" def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") super().__init__(key, model_name, base_url, **kwargs) self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name class OpenAI_APIChat(Base): _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("url cannot be None") model_name = model_name.split("___")[0] super().__init__(key, model_name, base_url, **kwargs) class LeptonAIChat(Base): _FACTORY_NAME = "LeptonAI" def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1") super().__init__(key, model_name, base_url, **kwargs) class ReplicateChat(Base): _FACTORY_NAME = "Replicate" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from replicate.client import Client self.model_name = model_name self.client = Client(api_token=key) def _chat(self, history, gen_conf={}, **kwargs): system = history[0]["content"] if history and history[0]["role"] == "system" else "" prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) response = self.client.run( self.model_name, input={"system_prompt": system, "prompt": prompt, **gen_conf}, ) ans = "".join(response) return ans, num_tokens_from_string(ans) def chat_streamly(self, system, history, gen_conf={}, **kwargs): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) ans = "" try: response = self.client.run( self.model_name, input={"system_prompt": system, "prompt": prompt, **gen_conf}, ) for resp in response: ans = resp yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield num_tokens_from_string(ans) class HunyuanChat(Base): _FACTORY_NAME = "Tencent Hunyuan" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from tencentcloud.common import credential from tencentcloud.hunyuan.v20230901 import hunyuan_client key = json.loads(key) sid = key.get("hunyuan_sid", "") sk = key.get("hunyuan_sk", "") cred = credential.Credential(sid, sk) self.model_name = model_name self.client = hunyuan_client.HunyuanClient(cred, "") def _clean_conf(self, gen_conf): _gen_conf = {} if "temperature" in gen_conf: _gen_conf["Temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: _gen_conf["TopP"] = gen_conf["top_p"] return _gen_conf def _chat(self, history, gen_conf={}, **kwargs): from tencentcloud.hunyuan.v20230901 import models hist = [{k.capitalize(): v for k, v in item.items()} for item in history] req = models.ChatCompletionsRequest() params = {"Model": self.model_name, "Messages": hist, **gen_conf} req.from_json_string(json.dumps(params)) response = self.client.ChatCompletions(req) ans = response.Choices[0].Message.Content return ans, response.Usage.TotalTokens def chat_streamly(self, system, history, gen_conf={}, **kwargs): from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( TencentCloudSDKException, ) from tencentcloud.hunyuan.v20230901 import models _gen_conf = {} _history = [{k.capitalize(): v for k, v in item.items()} for item in history] if system and history and history[0].get("role") != "system": _history.insert(0, {"Role": "system", "Content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] if "temperature" in gen_conf: _gen_conf["Temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: _gen_conf["TopP"] = gen_conf["top_p"] req = models.ChatCompletionsRequest() params = { "Model": self.model_name, "Messages": _history, "Stream": True, **_gen_conf, } req.from_json_string(json.dumps(params)) ans = "" total_tokens = 0 try: response = self.client.ChatCompletions(req) for resp in response: resp = json.loads(resp["data"]) if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]: continue ans = resp["Choices"][0]["Delta"]["Content"] total_tokens += 1 yield ans except TencentCloudSDKException as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens class SparkChat(Base): _FACTORY_NAME = "XunFei Spark" def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs): if not base_url: base_url = "https://spark-api-open.xf-yun.com/v1" model2version = { "Spark-Max": "generalv3.5", "Spark-Lite": "general", "Spark-Pro": "generalv3", "Spark-Pro-128K": "pro-128k", "Spark-4.0-Ultra": "4.0Ultra", } version2model = {v: k for k, v in model2version.items()} assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}" if model_name in model2version: model_version = model2version[model_name] else: model_version = model_name super().__init__(key, model_version, base_url, **kwargs) class BaiduYiyanChat(Base): _FACTORY_NAME = "BaiduYiyan" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) import qianfan key = json.loads(key) ak = key.get("yiyan_ak", "") sk = key.get("yiyan_sk", "") self.client = qianfan.ChatCompletion(ak=ak, sk=sk) self.model_name = model_name.lower() def _clean_conf(self, gen_conf): gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] return gen_conf def _chat(self, history, gen_conf): system = history[0]["content"] if history and history[0]["role"] == "system" else "" response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body ans = response["result"] return ans, total_token_count_from_response(response) def chat_streamly(self, system, history, gen_conf={}, **kwargs): gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] ans = "" total_tokens = 0 try: response = self.client.do(model=self.model_name, messages=history, system=system, stream=True, **gen_conf) for resp in response: resp = resp.body ans = resp["result"] total_tokens = total_token_count_from_response(resp) yield ans except Exception as e: return ans + "\n**ERROR**: " + str(e), 0 yield total_tokens class GoogleChat(Base): _FACTORY_NAME = "Google Cloud" def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) import base64 from google.oauth2 import service_account key = json.loads(key) access_token = json.loads(base64.b64decode(key.get("google_service_account_key", ""))) project_id = key.get("google_project_id", "") region = key.get("google_region", "") scopes = ["https://www.googleapis.com/auth/cloud-platform"] self.model_name = model_name if "claude" in self.model_name: from anthropic import AnthropicVertex from google.auth.transport.requests import Request if access_token: credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes) request = Request() credits.refresh(request) token = credits.token self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token) else: self.client = AnthropicVertex(region=region, project_id=project_id) else: from google import genai if access_token: credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes) self.client = genai.Client(vertexai=True, project=project_id, location=region, credentials=credits) else: self.client = genai.Client(vertexai=True, project=project_id, location=region) def _clean_conf(self, gen_conf): if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] else: if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] del gen_conf["max_tokens"] for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_output_tokens"]: del gen_conf[k] return gen_conf def _chat(self, history, gen_conf={}, **kwargs): system = history[0]["content"] if history and history[0]["role"] == "system" else "" if "claude" in self.model_name: gen_conf = self._clean_conf(gen_conf) response = self.client.messages.create( model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, stream=False, **gen_conf, ).json() ans = response["content"][0]["text"] if response["stop_reason"] == "max_tokens": ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ( ans, response["usage"]["input_tokens"] + response["usage"]["output_tokens"], ) # Gemini models with google-genai SDK # Set default thinking_budget=0 if not specified if "thinking_budget" not in gen_conf: gen_conf["thinking_budget"] = 0 thinking_budget = gen_conf.pop("thinking_budget", 0) gen_conf = self._clean_conf(gen_conf) # Build GenerateContentConfig try: from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part except ImportError as e: logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") raise config_dict = {} if system: config_dict["system_instruction"] = system if "temperature" in gen_conf: config_dict["temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: config_dict["top_p"] = gen_conf["top_p"] if "max_output_tokens" in gen_conf: config_dict["max_output_tokens"] = gen_conf["max_output_tokens"] # Add ThinkingConfig config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget) config = GenerateContentConfig(**config_dict) # Convert history to google-genai Content format contents = [] for item in history: if item["role"] == "system": continue # google-genai uses 'model' instead of 'assistant' role = "model" if item["role"] == "assistant" else item["role"] content = Content( role=role, parts=[Part(text=item["content"])] ) contents.append(content) response = self.client.models.generate_content( model=self.model_name, contents=contents, config=config ) ans = response.text # Get token count from response try: total_tokens = response.usage_metadata.total_token_count except Exception: total_tokens = 0 return ans, total_tokens def chat_streamly(self, system, history, gen_conf={}, **kwargs): if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] ans = "" total_tokens = 0 try: response = self.client.messages.create( model=self.model_name, messages=history, system=system, stream=True, **gen_conf, ) for res in response.iter_lines(): res = res.decode("utf-8") if "content_block_delta" in res and "data" in res: text = json.loads(res[6:])["delta"]["text"] ans = text total_tokens += num_tokens_from_string(text) except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens else: # Gemini models with google-genai SDK ans = "" total_tokens = 0 # Set default thinking_budget=0 if not specified if "thinking_budget" not in gen_conf: gen_conf["thinking_budget"] = 0 thinking_budget = gen_conf.pop("thinking_budget", 0) gen_conf = self._clean_conf(gen_conf) # Build GenerateContentConfig try: from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part except ImportError as e: logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") raise config_dict = {} if system: config_dict["system_instruction"] = system if "temperature" in gen_conf: config_dict["temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: config_dict["top_p"] = gen_conf["top_p"] if "max_output_tokens" in gen_conf: config_dict["max_output_tokens"] = gen_conf["max_output_tokens"] # Add ThinkingConfig config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget) config = GenerateContentConfig(**config_dict) # Convert history to google-genai Content format contents = [] for item in history: # google-genai uses 'model' instead of 'assistant' role = "model" if item["role"] == "assistant" else item["role"] content = Content( role=role, parts=[Part(text=item["content"])] ) contents.append(content) try: for chunk in self.client.models.generate_content_stream( model=self.model_name, contents=contents, config=config ): text = chunk.text ans = text total_tokens += num_tokens_from_string(text) yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens class GPUStackChat(Base): _FACTORY_NAME = "GPUStack" def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") super().__init__(key, model_name, base_url, **kwargs) class TokenPonyChat(Base): _FACTORY_NAME = "TokenPony" def __init__(self, key, model_name, base_url="https://ragflow.vip-api.tokenpony.cn/v1", **kwargs): if not base_url: base_url = "https://ragflow.vip-api.tokenpony.cn/v1" super().__init__(key, model_name, base_url, **kwargs) class DeerAPIChat(Base): _FACTORY_NAME = "DeerAPI" def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs): if not base_url: base_url = "https://api.deerapi.com/v1" super().__init__(key, model_name, base_url, **kwargs) class LiteLLMBase(ABC): _FACTORY_NAME = [ "Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic", "Ollama", "LongCat", "CometAPI", "SILICONFLOW", "OpenRouter", "StepFun", "PPIO", "PerfXCloud", "Upstage", "NovitaAI", "01.AI", "GiteeAI", "302.AI", "Jiekou.AI", ] def __init__(self, key, model_name, base_url=None, **kwargs): self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.provider = kwargs.get("provider", "") self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "") self.model_name = f"{self.prefix}{model_name}" self.api_key = key self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")).rstrip("/") # Configure retry parameters self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) self.max_rounds = kwargs.get("max_rounds", 5) self.is_tools = False self.tools = [] self.toolcall_sessions = {} # Factory specific fields if self.provider == SupportedLiteLLMProvider.Bedrock: self.bedrock_ak = json.loads(key).get("bedrock_ak", "") self.bedrock_sk = json.loads(key).get("bedrock_sk", "") self.bedrock_region = json.loads(key).get("bedrock_region", "") elif self.provider == SupportedLiteLLMProvider.OpenRouter: self.api_key = json.loads(key).get("api_key", "") self.provider_order = json.loads(key).get("provider_order", "") def _get_delay(self): """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA), (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT), (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION), (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST), (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER), (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT), (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION), (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER), (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL), (["max rounds"], LLMErrorCode.ERROR_MODEL), ] for words, code in keywords_mapping: if re.search("({})".format("|".join(words)), error_str): return code return LLMErrorCode.ERROR_GENERIC def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] return gen_conf def _chat(self, history, gen_conf, **kwargs): logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwen3") >= 0: kwargs["extra_body"] = {"enable_thinking": False} completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, timeout=self.timeout, ) # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): return "", 0 ans = response.choices[0].message.content.strip() if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) return ans, total_token_count_from_response(response) def _chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) stop = kwargs.get("stop") if stop: completion_args["stop"] = stop response = litellm.completion( **completion_args, drop_params=True, timeout=self.timeout, ) for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: delta.content = "" if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content: ans = "" if not reasoning_start: reasoning_start = True ans = "" ans += delta.reasoning_content + "" else: reasoning_start = False ans = delta.content tol = total_token_count_from_response(resp) if not tol: tol = num_tokens_from_string(delta.content) finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" if finish_reason == "length": if is_chinese(ans): ans += LENGTH_NOTIFICATION_CN else: ans += LENGTH_NOTIFICATION_EN yield ans, tol def _length_stop(self, ans): if is_chinese([ans]): return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_EN @property def _retryable_errors(self) -> set[str]: return { LLMErrorCode.ERROR_RATE_LIMIT, LLMErrorCode.ERROR_SERVER, } def _should_retry(self, error_code: str) -> bool: return error_code in self._retryable_errors def _exceptions(self, e, attempt) -> str | None: logging.exception("OpenAI chat_with_tools") # Classify the error error_code = self._classify_error(e) if attempt == self.max_retries: error_code = LLMErrorCode.ERROR_MAX_RETRIES if self._should_retry(error_code): delay = self._get_delay() logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") time.sleep(delay) return None return f"{ERROR_PREFIX}: {error_code} - {str(e)}" def _verbose_tool_use(self, name, args, res): return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" def _append_history(self, hist, tool_call, tool_res): hist.append( { "role": "assistant", "tool_calls": [ { "index": tool_call.index, "id": tool_call.id, "function": { "name": tool_call.function.name, "arguments": tool_call.function.arguments, }, "type": "function", }, ], } ) try: if isinstance(tool_res, dict): tool_res = json.dumps(tool_res, ensure_ascii=False) finally: hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)}) return hist def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return self.is_tools = True self.toolcall_session = toolcall_session self.tools = tools def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): completion_args = { "model": self.model_name, "messages": history, "api_key": self.api_key, "num_retries": self.max_retries, **kwargs, } if stream: completion_args.update( { "stream": stream, } ) if tools and self.tools: completion_args.update( { "tools": self.tools, "tool_choice": "auto", } ) if self.provider in FACTORY_DEFAULT_BASE_URL: completion_args.update({"api_base": self.base_url}) elif self.provider == SupportedLiteLLMProvider.Bedrock: completion_args.pop("api_key", None) completion_args.pop("api_base", None) completion_args.update( { "aws_access_key_id": self.bedrock_ak, "aws_secret_access_key": self.bedrock_sk, "aws_region_name": self.bedrock_region, } ) if self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider_order: def _to_order_list(x): if x is None: return [] if isinstance(x, str): return [s.strip() for s in x.split(",") if s.strip()] if isinstance(x, (list, tuple)): return [str(s).strip() for s in x if str(s).strip()] return [] extra_body = {} provider_cfg = {} provider_order = _to_order_list(self.provider_order) provider_cfg["order"] = provider_order provider_cfg["allow_fallbacks"] = False extra_body["provider"] = provider_cfg completion_args.update({"extra_body": extra_body}) return completion_args def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) ans = "" tk_count = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): history = deepcopy(hist) # deepcopy is required here try: for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, timeout=self.timeout, ) tk_count += total_token_count_from_response(response) if not hasattr(response, "choices") or not response.choices or not response.choices[0].message: raise Exception(f"500 response structure error. Response: {response}") message = response.choices[0].message if not hasattr(message, "tool_calls") or not message.tool_calls: if hasattr(message, "reasoning_content") and message.reasoning_content: ans += f"{message.reasoning_content}" ans += message.content or "" if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) return ans, tk_count for tool_call in message.tool_calls: logging.info(f"Response {tool_call=}") name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) tool_response = self.toolcall_session.tool_call(name, args) history = self._append_history(history, tool_call, tool_response) ans += self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) ans += self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) response, token_count = self._chat(history, gen_conf) ans += response tk_count += token_count return ans, tk_count except Exception as e: e = self._exceptions(e, attempt) if e: return e, tk_count assert False, "Shouldn't be here." def chat(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): try: response = self._chat(history, gen_conf, **kwargs) return response except Exception as e: e = self._exceptions(e, attempt) if e: return e, 0 assert False, "Shouldn't be here." def _wrap_toolcall_message(self, stream): final_tool_calls = {} for chunk in stream: for tool_call in chunk.choices[0].delta.tool_calls or []: index = tool_call.index if index not in final_tool_calls: final_tool_calls[index] = tool_call final_tool_calls[index].function.arguments += tool_call.function.arguments return final_tool_calls def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) total_tokens = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): history = deepcopy(hist) # deepcopy is required here try: for _ in range(self.max_rounds + 1): reasoning_start = False logging.info(f"{tools=}") completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, timeout=self.timeout, ) final_tool_calls = {} answer = "" for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta if hasattr(delta, "tool_calls") and delta.tool_calls: for tool_call in delta.tool_calls: index = tool_call.index if index not in final_tool_calls: if not tool_call.function.arguments: tool_call.function.arguments = "" final_tool_calls[index] = tool_call else: final_tool_calls[index].function.arguments += tool_call.function.arguments or "" continue if not hasattr(delta, "content") or delta.content is None: delta.content = "" if hasattr(delta, "reasoning_content") and delta.reasoning_content: ans = "" if not reasoning_start: reasoning_start = True ans = "" ans += delta.reasoning_content + "" yield ans else: reasoning_start = False answer += delta.content yield delta.content tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(delta.content) else: total_tokens += tol finish_reason = getattr(resp.choices[0], "finish_reason", "") if finish_reason == "length": yield self._length_stop("") if answer: yield total_tokens return for tool_call in final_tool_calls.values(): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) yield self._verbose_tool_use(name, args, "Begin to call...") tool_response = self.toolcall_session.tool_call(name, args) history = self._append_history(history, tool_call, tool_response) yield self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") history.append( { "role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}", } ) yield self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, timeout=self.timeout, ) for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: continue tol = total_token_count_from_response(resp) if not tol: total_tokens += num_tokens_from_string(delta.content) else: total_tokens += tol yield delta.content yield total_tokens return except Exception as e: e = self._exceptions(e, attempt) if e: yield e yield total_tokens return assert False, "Shouldn't be here." def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 try: for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): yield delta_ans total_tokens += tol except openai.APIError as e: yield ans + "\n**ERROR**: " + str(e) yield total_tokens def _calculate_dynamic_ctx(self, history): """Calculate dynamic context window size""" def count_tokens(text): """Calculate token count for text""" # Simple calculation: 1 token per ASCII character # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) total = 0 for char in text: if ord(char) < 128: # ASCII characters total += 1 else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) total += 2 return total # Calculate total tokens for all messages total_tokens = 0 for message in history: content = message.get("content", "") # Calculate content tokens content_tokens = count_tokens(content) # Add role marker token overhead role_tokens = 4 total_tokens += content_tokens + role_tokens # Apply 1.2x buffer ratio total_tokens_with_buffer = int(total_tokens * 1.2) if total_tokens_with_buffer <= 8192: ctx_size = 8192 else: ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 ctx_size = ctx_multiplier * 8192 return ctx_size