mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor: Remove Useless split for BedrockEmbed (#9067)
### What problem does this PR solve? Remove Useless split for BedrockEmbed ### Type of change - [x] Refactoring
This commit is contained in:
@ -291,7 +291,7 @@ class OllamaEmbed(Base):
|
|||||||
arr = []
|
arr = []
|
||||||
tks_num = 0
|
tks_num = 0
|
||||||
for txt in texts:
|
for txt in texts:
|
||||||
# remove special tokens if they exist
|
# remove special tokens if they exist base on regex in one request
|
||||||
for token in OllamaEmbed._special_tokens:
|
for token in OllamaEmbed._special_tokens:
|
||||||
txt = txt.replace(token, "")
|
txt = txt.replace(token, "")
|
||||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||||
@ -487,6 +487,8 @@ class BedrockEmbed(Base):
|
|||||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||||
|
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||||
|
|
||||||
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||||
@ -499,9 +501,9 @@ class BedrockEmbed(Base):
|
|||||||
embeddings = []
|
embeddings = []
|
||||||
token_count = 0
|
token_count = 0
|
||||||
for text in texts:
|
for text in texts:
|
||||||
if self.model_name.split(".")[0] == "amazon":
|
if self.is_amazon:
|
||||||
body = {"inputText": text}
|
body = {"inputText": text}
|
||||||
elif self.model_name.split(".")[0] == "cohere":
|
elif self.is_cohere:
|
||||||
body = {"texts": [text], "input_type": "search_document"}
|
body = {"texts": [text], "input_type": "search_document"}
|
||||||
|
|
||||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||||
@ -517,9 +519,9 @@ class BedrockEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
embeddings = []
|
embeddings = []
|
||||||
token_count = num_tokens_from_string(text)
|
token_count = num_tokens_from_string(text)
|
||||||
if self.model_name.split(".")[0] == "amazon":
|
if self.is_amazon:
|
||||||
body = {"inputText": truncate(text, 8196)}
|
body = {"inputText": truncate(text, 8196)}
|
||||||
elif self.model_name.split(".")[0] == "cohere":
|
elif self.is_cohere:
|
||||||
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||||
|
|
||||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||||
|
|||||||
Reference in New Issue
Block a user