From 35539092d04ff1373038a64e852c24404eb1da95 Mon Sep 17 00:00:00 2001 From: so95 Date: Thu, 7 Aug 2025 08:45:37 +0700 Subject: [PATCH] Add **kwargs to model base class constructors (#9252) Updated constructors for base and derived classes in chat, embedding, rerank, sequence2txt, and tts models to accept **kwargs. This change improves extensibility and allows passing additional parameters without breaking existing interfaces. - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: IT: Sop.Son Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- rag/llm/chat_model.py | 4 ++-- rag/llm/embedding_model.py | 11 ++++++++--- rag/llm/rerank_model.py | 10 +++++++--- rag/llm/sequence2txt_model.py | 6 +++++- rag/llm/tts_model.py | 6 +++++- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 698264fbf..49c535d81 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1216,11 +1216,11 @@ class LmStudioChat(Base): class OpenAI_APIChat(Base): _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("url cannot be None") model_name = model_name.split("___")[0] - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class PPIOChat(Base): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2a270e82c..272e0ff0b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -37,7 +37,12 @@ from rag.utils import num_tokens_from_string, truncate class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Constructor for abstract base class. + Parameters are accepted for interface consistency but are not stored. + Subclasses should implement their own initialization as needed. + """ pass def encode(self, texts: list): @@ -864,7 +869,7 @@ class VoyageEmbed(Base): class HuggingFaceEmbed(Base): _FACTORY_NAME = "HuggingFace" - def __init__(self, key, model_name, base_url=None): + def __init__(self, key, model_name, base_url=None, **kwargs): if not model_name: raise ValueError("Model name cannot be None") self.key = key @@ -946,4 +951,4 @@ class Ai302Embed(Base): def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"): if not base_url: base_url = "https://api.302.ai/v1/embeddings" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url) \ No newline at end of file diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index a1fa466d8..3111a30d6 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -33,7 +33,11 @@ from api.utils.log_utils import log_exception from rag.utils import num_tokens_from_string, truncate class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; initialization is left to subclasses. + """ pass def similarity(self, query: str, texts: list): @@ -315,7 +319,7 @@ class NvidiaRerank(Base): class LmStudioRerank(Base): _FACTORY_NAME = "LM-Studio" - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): pass def similarity(self, query: str, texts: list): @@ -396,7 +400,7 @@ class CoHereRerank(Base): class TogetherAIRerank(Base): _FACTORY_NAME = "TogetherAI" - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): pass def similarity(self, query: str, texts: list): diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index 3a26b8835..27b83425f 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -28,7 +28,11 @@ from rag.utils import num_tokens_from_string class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; initialization is left to subclasses. + """ pass def transcription(self, audio, **kwargs): diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 2e944d6e9..9520cbbbf 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -63,7 +63,11 @@ class ServeTTSRequest(BaseModel): class Base(ABC): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; subclasses should handle their own initialization. + """ pass def tts(self, audio):