From a0c2da121980dd6738b9af15b9e82c656aebe36e Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Tue, 12 Aug 2025 15:54:30 +0800 Subject: [PATCH] Fix: Patch LiteLLM (#9416) ### What problem does this PR solve? Patch LiteLLM refactor. #9408 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/chat_model.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index c96afa12f..1f14cbf72 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1455,7 +1455,7 @@ class LiteLLMBase(ABC): if self.model_name.lower().find("qwen3") >= 0: kwargs["extra_body"] = {"enable_thinking": False} - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1475,7 +1475,7 @@ class LiteLLMBase(ABC): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) stop = kwargs.get("stop") if stop: completion_args["stop"] = stop @@ -1571,17 +1571,27 @@ class LiteLLMBase(ABC): self.toolcall_session = toolcall_session self.tools = tools - def _construct_completion_args(self, history, **kwargs): + def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): completion_args = { "model": self.model_name, "messages": history, - "stream": False, - "tools": self.tools, - "tool_choice": "auto", "api_key": self.api_key, **kwargs, } - if self.provider in SupportedLiteLLMProvider: + if stream: + completion_args.update( + { + "stream": stream, + } + ) + if tools and self.tools: + completion_args.update( + { + "tools": self.tools, + "tool_choice": "auto", + } + ) + if self.provider in FACTORY_DEFAULT_BASE_URL: completion_args.update({"api_base": self.base_url}) elif self.provider == SupportedLiteLLMProvider.Bedrock: completion_args.pop("api_key", None) @@ -1611,7 +1621,7 @@ class LiteLLMBase(ABC): for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1708,7 +1718,7 @@ class LiteLLMBase(ABC): reasoning_start = False logging.info(f"{tools=}") - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1786,7 +1796,7 @@ class LiteLLMBase(ABC): logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True,