mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fix: jina embedding issue (#11628)
### What problem does this PR solve? Fix: jina embedding issue #11614 Feat: Add jina embedding v4 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -1194,6 +1194,12 @@
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "jina-embeddings-v4",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "embedding"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -349,35 +349,6 @@ class YoudaoEmbed(Base):
|
||||
return np.array(embds[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class JinaEmbed(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
|
||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class JinaMultiVecEmbed(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
@ -403,11 +374,28 @@ class JinaMultiVecEmbed(Base):
|
||||
img_b64s = base64.b64encode(text).decode('utf8')
|
||||
input.append({"image": img_b64s}) # base64 encoded image
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]}
|
||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||
if "v4" in self.model_name:
|
||||
data["return_multivector"] = True
|
||||
|
||||
if "v3" in self.model_name or "v4" in self.model_name:
|
||||
data['task'] = task
|
||||
data['truncate'] = True
|
||||
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embeddings"] for d in res["data"]])
|
||||
for d in res['data']:
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
||||
chunk_emb = token_embs.mean(axis=0)
|
||||
|
||||
else:
|
||||
# v2/v3
|
||||
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
||||
|
||||
ress.append(chunk_emb)
|
||||
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
Reference in New Issue
Block a user