diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 3c6ea4d09..55faa6b32 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1217,13 +1217,12 @@ class GeminiChat(Base):
 
     def _clean_conf(self, gen_conf):
         for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p"]:
+            if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
         return gen_conf
 
     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
-
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1246,12 +1245,9 @@ class GeminiChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+        gen_conf = self._clean_conf(gen_conf)
         if system:
             self.model._system_instruction = content_types.to_content(system)
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
         for item in history:
             if "role" in item and item["role"] == "assistant":
                 item["role"] = "model"