Fix: Patch LiteLLM (#9416)

### What problem does this PR solve?

Patch LiteLLM refactor. #9408

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Yongteng Lei
2025-08-12 15:54:30 +08:00
committed by GitHub
parent 79e2edc835
commit a0c2da1219

View File

@ -1455,7 +1455,7 @@ class LiteLLMBase(ABC):
if self.model_name.lower().find("qwen3") >= 0: if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False} kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
@ -1475,7 +1475,7 @@ class LiteLLMBase(ABC):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False reasoning_start = False
completion_args = self._construct_completion_args(history=history, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop") stop = kwargs.get("stop")
if stop: if stop:
completion_args["stop"] = stop completion_args["stop"] = stop
@ -1571,17 +1571,27 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session self.toolcall_session = toolcall_session
self.tools = tools self.tools = tools
def _construct_completion_args(self, history, **kwargs): def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = { completion_args = {
"model": self.model_name, "model": self.model_name,
"messages": history, "messages": history,
"stream": False,
"tools": self.tools,
"tool_choice": "auto",
"api_key": self.api_key, "api_key": self.api_key,
**kwargs, **kwargs,
} }
if self.provider in SupportedLiteLLMProvider: if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url}) completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock: elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None) completion_args.pop("api_key", None)
@ -1611,7 +1621,7 @@ class LiteLLMBase(ABC):
for _ in range(self.max_rounds + 1): for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}") logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
@ -1708,7 +1718,7 @@ class LiteLLMBase(ABC):
reasoning_start = False reasoning_start = False
logging.info(f"{tools=}") logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
@ -1786,7 +1796,7 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}") logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,