diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 91507d147..d028e75e3 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1188,8 +1188,36 @@ class GoogleChat(Base): del gen_conf[k] return gen_conf + def _get_thinking_config(self, gen_conf): + """Extract and create ThinkingConfig from gen_conf. + + Default behavior for Vertex AI Generative Models: thinking_budget=0 (disabled) + unless explicitly specified by the user. This does not apply to Claude models. + + Users can override by setting thinking_budget in gen_conf/llm_setting: + - 0: Disabled (default) + - 1-24576: Manual budget + - -1: Auto (model decides) + """ + # Claude models don't support ThinkingConfig + if "claude" in self.model_name: + gen_conf.pop("thinking_budget", None) + return None + + # For Vertex AI Generative Models, default to thinking disabled + thinking_budget = gen_conf.pop("thinking_budget", 0) + + if thinking_budget is not None: + try: + import vertexai.generative_models as glm # type: ignore + return glm.ThinkingConfig(thinking_budget=thinking_budget) + except Exception: + pass + return None + def _chat(self, history, gen_conf={}, **kwargs): system = history[0]["content"] if history and history[0]["role"] == "system" else "" + thinking_config = self._get_thinking_config(gen_conf) gen_conf = self._clean_conf(gen_conf) if "claude" in self.model_name: response = self.client.messages.create( @@ -1223,7 +1251,10 @@ class GoogleChat(Base): } ] - response = self.client.generate_content(hist, generation_config=gen_conf) + if thinking_config: + response = self.client.generate_content(hist, generation_config=gen_conf, thinking_config=thinking_config) + else: + response = self.client.generate_content(hist, generation_config=gen_conf) ans = response.text return ans, response.usage_metadata.total_token_count @@ -1255,6 +1286,7 @@ class GoogleChat(Base): response = None total_tokens = 0 self.client._system_instruction = system + thinking_config = self._get_thinking_config(gen_conf) if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] del gen_conf["max_tokens"] @@ -1272,7 +1304,10 @@ class GoogleChat(Base): ] ans = "" try: - response = self.client.generate_content(history, generation_config=gen_conf, stream=True) + if thinking_config: + response = self.client.generate_content(history, generation_config=gen_conf, thinking_config=thinking_config, stream=True) + else: + response = self.client.generate_content(history, generation_config=gen_conf, stream=True) for resp in response: ans = resp.text total_tokens += num_tokens_from_string(ans)