mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor: Remove Useless split for BedrockEmbed (#9067)
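### What problem does this PR solve?

Remove the redundant `model_name.split(".")` calls in BedrockEmbed: the provider prefix is now checked once and stored as `is_amazon` / `is_cohere`.

### Type of change

- [x] Refactoring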
This commit is contained in:
@ -291,7 +291,7 @@ class OllamaEmbed(Base):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
# remove special tokens if they exist
|
||||
# remove special tokens if they exist, based on regex, in one request
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
txt = txt.replace(token, "")
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
@ -487,6 +487,8 @@ class BedrockEmbed(Base):
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||
@ -499,9 +501,9 @@ class BedrockEmbed(Base):
|
||||
embeddings = []
|
||||
token_count = 0
|
||||
for text in texts:
|
||||
if self.model_name.split(".")[0] == "amazon":
|
||||
if self.is_amazon:
|
||||
body = {"inputText": text}
|
||||
elif self.model_name.split(".")[0] == "cohere":
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [text], "input_type": "search_document"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
@ -517,9 +519,9 @@ class BedrockEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
embeddings = []
|
||||
token_count = num_tokens_from_string(text)
|
||||
if self.model_name.split(".")[0] == "amazon":
|
||||
if self.is_amazon:
|
||||
body = {"inputText": truncate(text, 8196)}
|
||||
elif self.model_name.split(".")[0] == "cohere":
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
|
||||
Reference in New Issue
Block a user