diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 698264fbf..49c535d81 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1216,11 +1216,11 @@ class LmStudioChat(Base): class OpenAI_APIChat(Base): _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("url cannot be None") model_name = model_name.split("___")[0] - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class PPIOChat(Base): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2a270e82c..272e0ff0b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -37,7 +37,12 @@ from rag.utils import num_tokens_from_string, truncate class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Constructor for abstract base class. + Parameters are accepted for interface consistency but are not stored. + Subclasses should implement their own initialization as needed. + """ pass def encode(self, texts: list): @@ -864,7 +869,7 @@ class VoyageEmbed(Base): class HuggingFaceEmbed(Base): _FACTORY_NAME = "HuggingFace" - def __init__(self, key, model_name, base_url=None): + def __init__(self, key, model_name, base_url=None, **kwargs): if not model_name: raise ValueError("Model name cannot be None") self.key = key @@ -946,4 +951,4 @@ class Ai302Embed(Base): def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"): if not base_url: base_url = "https://api.302.ai/v1/embeddings" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url) \ No newline at end of file diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index a1fa466d8..3111a30d6 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -33,7 +33,11 @@ from api.utils.log_utils import log_exception from rag.utils import num_tokens_from_string, truncate class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; initialization is left to subclasses. + """ pass def similarity(self, query: str, texts: list): @@ -315,7 +319,7 @@ class NvidiaRerank(Base): class LmStudioRerank(Base): _FACTORY_NAME = "LM-Studio" - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): pass def similarity(self, query: str, texts: list): @@ -396,7 +400,7 @@ class CoHereRerank(Base): class TogetherAIRerank(Base): _FACTORY_NAME = "TogetherAI" - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): pass def similarity(self, query: str, texts: list): diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index 3a26b8835..27b83425f 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -28,7 +28,11 @@ from rag.utils import num_tokens_from_string class Base(ABC): - def __init__(self, key, model_name): + def __init__(self, key, model_name, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; initialization is left to subclasses. + """ pass def transcription(self, audio, **kwargs): diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 2e944d6e9..9520cbbbf 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -63,7 +63,11 @@ class ServeTTSRequest(BaseModel): class Base(ABC): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): + """ + Abstract base class constructor. + Parameters are not stored; subclasses should handle their own initialization. + """ pass def tts(self, audio):