diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index c96afa12f..1f14cbf72 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1455,7 +1455,7 @@ class LiteLLMBase(ABC): if self.model_name.lower().find("qwen3") >= 0: kwargs["extra_body"] = {"enable_thinking": False} - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1475,7 +1475,7 @@ class LiteLLMBase(ABC): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) stop = kwargs.get("stop") if stop: completion_args["stop"] = stop @@ -1571,17 +1571,27 @@ class LiteLLMBase(ABC): self.toolcall_session = toolcall_session self.tools = tools - def _construct_completion_args(self, history, **kwargs): + def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): completion_args = { "model": self.model_name, "messages": history, - "stream": False, - "tools": self.tools, - "tool_choice": "auto", "api_key": self.api_key, **kwargs, } - if self.provider in SupportedLiteLLMProvider: + if stream: + completion_args.update( + { + "stream": stream, + } + ) + if tools and self.tools: + completion_args.update( + { + "tools": self.tools, + "tool_choice": "auto", + } + ) + if self.provider in FACTORY_DEFAULT_BASE_URL: completion_args.update({"api_base": self.base_url}) elif self.provider == SupportedLiteLLMProvider.Bedrock: completion_args.pop("api_key", None) @@ -1611,7 +1621,7 @@ class LiteLLMBase(ABC): for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1708,7 +1718,7 @@ class LiteLLMBase(ABC): reasoning_start = False logging.info(f"{tools=}") - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True, @@ -1786,7 +1796,7 @@ class LiteLLMBase(ABC): logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - completion_args = self._construct_completion_args(history=history, **gen_conf) + completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = litellm.completion( **completion_args, drop_params=True,