From d94386e00a9cb3a9e5d3f66e3947ed0da2a9105b Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Fri, 29 Nov 2024 14:52:27 +0800
Subject: [PATCH] Pass top_p to ollama (#3744)

### What problem does this PR solve?

Pass top_p to ollama. Close #1769

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/llm/chat_model.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 9dea59a72..90786c58f 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -356,7 +356,7 @@ class OllamaChat(Base):
             options = {}
             if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
             response = self.client.chat(
@@ -376,7 +376,7 @@ class OllamaChat(Base):
         options = {}
         if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
         if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
         if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
         ans = ""
@@ -430,7 +430,7 @@ class LocalLLM(Base):
                     try:
                         self._connection.send(pickle.dumps((name, args, kwargs)))
                         return pickle.loads(self._connection.recv())
-                    except Exception as e:
+                    except Exception:
                         self.__conn()
                 raise Exception("RPC connection lost!")
 
@@ -442,7 +442,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Prompt
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -450,7 +450,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Generation
         answer = ""
         try:
             res = self.client.stream_doc(
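
Reviewer note (not part of the patch): Ollama's `top_k` is an integer vocabulary cutoff, while `top_p` is a nucleus-sampling probability in [0, 1], so forwarding `top_p` under the `top_k` key silently misconfigured sampling. Below is a minimal sketch of the corrected `gen_conf` → Ollama `options` mapping, assuming the official `ollama` Python client; the `build_ollama_options` helper name, host, and model name are illustrative placeholders, not code from this patch.

```python
# Illustrative sketch only: the gen_conf -> options translation that
# OllamaChat performs, with top_p now forwarded as top_p instead of top_k.
# Helper name, host, and model below are placeholders.
from ollama import Client


def build_ollama_options(gen_conf: dict) -> dict:
    options = {}
    if "temperature" in gen_conf:
        options["temperature"] = gen_conf["temperature"]
    if "max_tokens" in gen_conf:
        options["num_predict"] = gen_conf["max_tokens"]  # Ollama's name for max_tokens
    if "top_p" in gen_conf:
        options["top_p"] = gen_conf["top_p"]  # previously sent as top_k by mistake
    if "presence_penalty" in gen_conf:
        options["presence_penalty"] = gen_conf["presence_penalty"]
    if "frequency_penalty" in gen_conf:
        options["frequency_penalty"] = gen_conf["frequency_penalty"]
    return options


client = Client(host="http://localhost:11434")  # placeholder host
response = client.chat(
    model="llama3",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    options=build_ollama_options({"temperature": 0.7, "top_p": 0.9, "max_tokens": 128}),
)
print(response["message"]["content"])
```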