diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 79304d27a..68aaf0d00 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -870,6 +870,7 @@ class MistralChat(Base): return gen_conf def _chat(self, history, gen_conf={}, **kwargs): + gen_conf = self._clean_conf(gen_conf) response = self.client.chat(model=self.model_name, messages=history, **gen_conf) ans = response.choices[0].message.content if response.choices[0].finish_reason == "length": @@ -882,9 +883,7 @@ class MistralChat(Base): def chat_streamly(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) - for k in list(gen_conf.keys()): - if k not in ["temperature", "top_p", "max_tokens"]: - del gen_conf[k] + gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 try: