From 244d8a47b9c60909b6f991d09963e2024fba26df Mon Sep 17 00:00:00 2001 From: Liu An Date: Mon, 23 Jun 2025 15:59:25 +0800 Subject: [PATCH] Fix: AzureChat model code (#8426) ### What problem does this PR solve? - Simplify AzureChat constructor by passing base_url directly - Clean up spacing and formatting in chat_model.py - Remove redundant parentheses and improve code consistency - #8423 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/chat_model.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index fbef34781..b1235ba62 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -157,9 +157,9 @@ class Base(ABC): tk_count = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries+1): + for attempt in range(self.max_retries + 1): history = hist - for _ in range(self.max_rounds*2): + for _ in range(self.max_rounds * 2): try: response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf) tk_count += self.total_token_count(response) @@ -185,7 +185,6 @@ class Base(ABC): except Exception as e: history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - except Exception as e: e = self._exceptions(e, attempt) if e: @@ -198,7 +197,7 @@ class Base(ABC): gen_conf = self._clean_conf(gen_conf) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries+1): + for attempt in range(self.max_retries + 1): try: return self._chat(history, gen_conf) except Exception as e: @@ -232,9 +231,9 @@ class Base(ABC): total_tokens = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries+1): + for attempt in range(self.max_retries + 1): history = hist - for _ in range(self.max_rounds*2): + for _ in range(self.max_rounds * 2): reasoning_start = False try: response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) @@ -453,11 +452,11 @@ class DeepSeekChat(Base): class AzureChat(Base): - def __init__(self, key, model_name, **kwargs): + def __init__(self, key, model_name, base_url, **kwargs): api_key = json.loads(key).get("api_key", "") api_version = json.loads(key).get("api_version", "2024-02-01") - super().__init__(key, model_name, kwargs["base_url"], **kwargs) - self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) + super().__init__(key, model_name, base_url, **kwargs) + self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) self.model_name = model_name @@ -925,10 +924,10 @@ class LocalAIChat(Base): class LocalLLM(Base): - def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from jina import Client + self.client = Client(port=12345, protocol="grpc", asyncio=True) def _prepare_prompt(self, system, history, gen_conf): @@ -985,13 +984,7 @@ class VolcEngineChat(Base): class MiniMaxChat(Base): - def __init__( - self, - key, - model_name, - base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", - **kwargs - ): + def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: @@ -1223,6 +1216,7 @@ class GeminiChat(Base): def _chat(self, history, gen_conf): from google.generativeai.types import content_types + system = history[0]["content"] if history and history[0]["role"] == "system" else "" hist = [] for item in history: @@ -1880,4 +1874,4 @@ class GPUStackChat(Base): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) \ No newline at end of file + super().__init__(key, model_name, base_url, **kwargs)