mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-05 18:15:06 +08:00
Feat: support doubao-embedding-vision model (#12983)
### What problem does this PR solve? Add support for the `doubao-embedding-vision` model. The `doubao-embedding-large-text` model is deprecated. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -379,7 +379,7 @@ class JinaMultiVecEmbed(Base):
|
||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||
if "v4" in self.model_name:
|
||||
data["return_multivector"] = True
|
||||
|
||||
|
||||
if "v3" in self.model_name or "v4" in self.model_name:
|
||||
data['task'] = task
|
||||
data['truncate'] = True
|
||||
@ -391,7 +391,7 @@ class JinaMultiVecEmbed(Base):
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
||||
chunk_emb = token_embs.mean(axis=0)
|
||||
|
||||
|
||||
else:
|
||||
# v2/v3
|
||||
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
||||
@ -481,7 +481,7 @@ class BedrockEmbed(Base):
|
||||
self.model_name = model_name
|
||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||
|
||||
|
||||
if mode == "access_key_secret":
|
||||
self.bedrock_ak = key.get("bedrock_ak")
|
||||
self.bedrock_sk = key.get("bedrock_sk")
|
||||
@ -885,15 +885,70 @@ class HuggingFaceEmbed(Base):
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
class VolcEngineEmbed(Base):
    """Embedding client for VolcEngine Ark (e.g. doubao-embedding-vision).

    NOTE(review): this span is a mangled diff view that interleaves the
    pre-change `OpenAIEmbed`-based __init__ with the post-change code.
    Reconstructed here as the post-change `Base`-based version (the new
    hunk header says the class grew from 15 to 70 lines) — confirm
    against the upstream commit.
    """

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        # An explicitly empty/None base_url falls back to the public Ark endpoint.
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        self.base_url = base_url

        # `key` is a JSON string carrying the Ark credentials.
        cfg = json.loads(key)
        self.ark_api_key = cfg.get("ark_api_key", "")
        self.model_name = model_name
|
||||
|
||||
@staticmethod
|
||||
def _extract_embedding(result: dict) -> list[float]:
|
||||
if not isinstance(result, dict):
|
||||
raise TypeError(f"Unexpected response type: {type(result)}")
|
||||
|
||||
data = result.get("data")
|
||||
if data is None:
|
||||
raise KeyError("Missing 'data' in response")
|
||||
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
raise ValueError("Empty 'data' in response")
|
||||
item = data[0]
|
||||
elif isinstance(data, dict):
|
||||
item = data
|
||||
else:
|
||||
raise TypeError(f"Unexpected 'data' type: {type(data)}")
|
||||
|
||||
if not isinstance(item, dict):
|
||||
raise TypeError("Unexpected item shape in 'data'")
|
||||
if "embedding" not in item:
|
||||
raise KeyError("Missing 'embedding' in response item")
|
||||
return item["embedding"]
|
||||
|
||||
def _encode_texts(self, texts: list[str]):
    """POST each text to the Ark multimodal embeddings endpoint.

    Returns a tuple of (np.ndarray of embedding vectors, total token
    count). Items whose response cannot be parsed are logged and
    skipped rather than aborting the whole batch.
    """
    from common.http_client import sync_request

    endpoint = f"{self.base_url}/embeddings/multimodal"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"}

    vectors: list[list[float]] = []
    token_count = 0
    for chunk in texts:
        payload = {"model": self.model_name, "input": [{"type": "text", "text": chunk}]}
        resp = sync_request(method="POST", url=endpoint, headers=headers, json=payload, timeout=60)
        # A non-200 status aborts the batch outright.
        if resp.status_code != 200:
            raise Exception(f"Error: {resp.status_code} - {resp.text}")
        parsed = resp.json()
        try:
            vectors.append(self._extract_embedding(parsed))
            token_count += total_token_count_from_response(parsed)
        except Exception as _e:
            # Best-effort: log the malformed item and continue.
            log_exception(_e)

    return np.array(vectors), token_count
|
||||
|
||||
def encode(self, texts: list):
    """Embed a batch of texts; returns (embeddings array, token usage)."""
    return self._encode_texts(texts)
|
||||
|
||||
def encode_queries(self, text: str):
    """Embed a single query string; returns (vector, token usage)."""
    vecs, used_tokens = self._encode_texts([text])
    return vecs[0], used_tokens
|
||||
|
||||
|
||||
class GPUStackEmbed(OpenAIEmbed):
|
||||
|
||||
Reference in New Issue
Block a user