diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index e460de6e2..7e8494b96 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -291,7 +291,7 @@ class OllamaEmbed(Base): arr = [] tks_num = 0 for txt in texts: - # remove special tokens if they exist + # remove special tokens if they exist before sending the request for token in OllamaEmbed._special_tokens: txt = txt.replace(token, "") res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) @@ -487,6 +487,8 @@ class BedrockEmbed(Base): self.bedrock_sk = json.loads(key).get("bedrock_sk", "") self.bedrock_region = json.loads(key).get("bedrock_region", "") self.model_name = model_name + self.is_amazon = self.model_name.split(".")[0] == "amazon" + self.is_cohere = self.model_name.split(".")[0] == "cohere" if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "": # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) @@ -499,9 +501,9 @@ embeddings = [] token_count = 0 for text in texts: - if self.model_name.split(".")[0] == "amazon": + if self.is_amazon: body = {"inputText": text} - elif self.model_name.split(".")[0] == "cohere": + elif self.is_cohere: body = {"texts": [text], "input_type": "search_document"} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) @@ -517,9 +519,9 @@ def encode_queries(self, text): embeddings = [] token_count = num_tokens_from_string(text) - if self.model_name.split(".")[0] == "amazon": + if self.is_amazon: body = {"inputText": truncate(text, 8196)} - elif self.model_name.split(".")[0] == "cohere": + elif self.is_cohere: body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))