mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-05 18:15:06 +08:00
Feat: support doubao-embedding-vision model (#12983)
### What problem does this PR solve? Adds support for the `doubao-embedding-vision` model, since `doubao-embedding-large-text` is deprecated. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -379,7 +379,7 @@ class JinaMultiVecEmbed(Base):
|
|||||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||||
if "v4" in self.model_name:
|
if "v4" in self.model_name:
|
||||||
data["return_multivector"] = True
|
data["return_multivector"] = True
|
||||||
|
|
||||||
if "v3" in self.model_name or "v4" in self.model_name:
|
if "v3" in self.model_name or "v4" in self.model_name:
|
||||||
data['task'] = task
|
data['task'] = task
|
||||||
data['truncate'] = True
|
data['truncate'] = True
|
||||||
@ -391,7 +391,7 @@ class JinaMultiVecEmbed(Base):
|
|||||||
if data.get("return_multivector", False): # v4
|
if data.get("return_multivector", False): # v4
|
||||||
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
||||||
chunk_emb = token_embs.mean(axis=0)
|
chunk_emb = token_embs.mean(axis=0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# v2/v3
|
# v2/v3
|
||||||
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
||||||
@ -481,7 +481,7 @@ class BedrockEmbed(Base):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||||
|
|
||||||
if mode == "access_key_secret":
|
if mode == "access_key_secret":
|
||||||
self.bedrock_ak = key.get("bedrock_ak")
|
self.bedrock_ak = key.get("bedrock_ak")
|
||||||
self.bedrock_sk = key.get("bedrock_sk")
|
self.bedrock_sk = key.get("bedrock_sk")
|
||||||
@ -885,15 +885,70 @@ class HuggingFaceEmbed(Base):
|
|||||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
|
||||||
class VolcEngineEmbed(Base):
    """Text embedding backed by VolcEngine Ark.

    Sends one POST per input text to ``{base_url}/embeddings/multimodal``
    and collects the returned vectors into a single NumPy matrix together
    with the accumulated token count.
    """

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        # A falsy base_url (e.g. empty string from config) falls back to
        # the public Ark endpoint.
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        self.base_url = base_url
        # `key` arrives as a JSON document; only the API key is read here.
        credentials = json.loads(key)
        self.ark_api_key = credentials.get("ark_api_key", "")
        self.model_name = model_name

    @staticmethod
    def _extract_embedding(result: dict) -> list[float]:
        """Pull the embedding vector out of one Ark response payload.

        Accepts ``data`` as either a single dict or a non-empty list of
        dicts; raises a descriptive error for any other shape.
        """
        if not isinstance(result, dict):
            raise TypeError(f"Unexpected response type: {type(result)}")
        data = result.get("data")
        if data is None:
            raise KeyError("Missing 'data' in response")
        if isinstance(data, dict):
            item = data
        elif isinstance(data, list):
            if not data:
                raise ValueError("Empty 'data' in response")
            item = data[0]
        else:
            raise TypeError(f"Unexpected 'data' type: {type(data)}")
        if not isinstance(item, dict):
            raise TypeError("Unexpected item shape in 'data'")
        if "embedding" not in item:
            raise KeyError("Missing 'embedding' in response item")
        return item["embedding"]

    def _encode_texts(self, texts: list[str]):
        """Embed every text with one API call each.

        Returns a tuple ``(np.ndarray of vectors, total token count)``.
        Raises on any non-200 HTTP status.
        """
        from common.http_client import sync_request

        url = f"{self.base_url}/embeddings/multimodal"
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"}

        vectors: list[list[float]] = []
        token_count = 0
        for text in texts:
            payload = {"model": self.model_name, "input": [{"type": "text", "text": text}]}
            response = sync_request(method="POST", url=url, headers=headers, json=payload, timeout=60)
            if response.status_code != 200:
                raise Exception(f"Error: {response.status_code} - {response.text}")
            result = response.json()
            try:
                vectors.append(self._extract_embedding(result))
                token_count += total_token_count_from_response(result)
            except Exception as exc:
                # NOTE(review): failure handling is delegated to
                # log_exception; whether the loop continues afterwards
                # depends on that helper's behavior — confirm it re-raises.
                log_exception(exc)

        return np.array(vectors), token_count

    def encode(self, texts: list):
        # Batch entry point: delegate directly to the shared worker.
        return self._encode_texts(texts)

    def encode_queries(self, text: str):
        # Single-query entry point: embed a one-element batch and unwrap it.
        embeddings, tokens = self._encode_texts([text])
        return embeddings[0], tokens
||||||
class GPUStackEmbed(OpenAIEmbed):
|
class GPUStackEmbed(OpenAIEmbed):
|
||||||
|
|||||||
Reference in New Issue
Block a user