Refactor: Remove redundant split calls in BedrockEmbed (#9067)

### What problem does this PR solve?

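Remove the redundant `self.model_name.split(".")[0]` calls in `BedrockEmbed`. The provider prefix is now checked once, where `self.model_name` is assigned, and cached as `self.is_amazon` / `self.is_cohere`; the embedding methods read those flags instead of re-splitting the model name on every call.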

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2025-07-28 10:16:38 +08:00
committed by GitHub
parent 0fccd1fef3
commit 86b4da0844

View File

@ -291,7 +291,7 @@ class OllamaEmbed(Base):
arr = [] arr = []
tks_num = 0 tks_num = 0
for txt in texts: for txt in texts:
# remove special tokens if they exist # remove special tokens if they exist base on regex in one request
for token in OllamaEmbed._special_tokens: for token in OllamaEmbed._special_tokens:
txt = txt.replace(token, "") txt = txt.replace(token, "")
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
@ -487,6 +487,8 @@ class BedrockEmbed(Base):
self.bedrock_sk = json.loads(key).get("bedrock_sk", "") self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "") self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name self.model_name = model_name
self.is_amazon = self.model_name.split(".")[0] == "amazon"
self.is_cohere = self.model_name.split(".")[0] == "cohere"
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "": if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
@ -499,9 +501,9 @@ class BedrockEmbed(Base):
embeddings = [] embeddings = []
token_count = 0 token_count = 0
for text in texts: for text in texts:
if self.model_name.split(".")[0] == "amazon": if self.is_amazon:
body = {"inputText": text} body = {"inputText": text}
elif self.model_name.split(".")[0] == "cohere": elif self.is_cohere:
body = {"texts": [text], "input_type": "search_document"} body = {"texts": [text], "input_type": "search_document"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
@ -517,9 +519,9 @@ class BedrockEmbed(Base):
def encode_queries(self, text): def encode_queries(self, text):
embeddings = [] embeddings = []
token_count = num_tokens_from_string(text) token_count = num_tokens_from_string(text)
if self.model_name.split(".")[0] == "amazon": if self.is_amazon:
body = {"inputText": truncate(text, 8196)} body = {"inputText": truncate(text, 8196)}
elif self.model_name.split(".")[0] == "cohere": elif self.is_cohere:
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))