diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 68aaf0d00..c966be2f7 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -641,6 +641,10 @@ class ZhipuChat(Base): def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] + gen_conf = self._clean_conf_plealty(gen_conf) + return gen_conf + + def _clean_conf_plealty(self, gen_conf): if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: @@ -648,22 +652,14 @@ class ZhipuChat(Base): return gen_conf def chat_with_tools(self, system: str, history: list, gen_conf: dict): - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] + gen_conf = self._clean_conf_plealty(gen_conf) return super().chat_with_tools(system, history, gen_conf) def chat_streamly(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] + gen_conf = self._clean_conf(gen_conf) ans = "" tk_count = 0 try: @@ -689,11 +685,7 @@ class ZhipuChat(Base): yield tk_count def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] - + gen_conf = self._clean_conf_plealty(gen_conf) return super().chat_streamly_with_tools(system, history, gen_conf)