diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 4f37cf482..698264fbf 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1075,6 +1075,9 @@ class GeminiChat(Base): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] + # if max_tokens exists, rename it to max_output_tokens to match Gemini's API + if k == "max_tokens": + gen_conf["max_output_tokens"] = gen_conf.pop("max_tokens") return gen_conf def _chat(self, history, gen_conf={}, **kwargs):