mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Don't select vector on infinity (#11151)
### What problem does this PR solve? Don't select vector on infinity ### Type of change - [x] Performance Improvement
This commit is contained in:
@ -31,6 +31,7 @@ from common.log_utils import log_exception
|
||||
from common.token_utils import num_tokens_from_string, truncate
|
||||
from common import settings
|
||||
import logging
|
||||
import base64
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -377,6 +378,46 @@ class JinaEmbed(Base):
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class JinaMultiVecEmbed(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-embeddings-v4", base_url="https://api.jina.ai/v1/embeddings"):
|
||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list[str|bytes], task="retrieval.passage"):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
input = []
|
||||
for text in texts:
|
||||
if isinstance(text, str):
|
||||
input.append({"text": text})
|
||||
elif isinstance(text, bytes):
|
||||
img_b64s = None
|
||||
try:
|
||||
base64.b64decode(text, validate=True)
|
||||
img_b64s = text.decode('utf8')
|
||||
except Exception:
|
||||
img_b64s = base64.b64encode(text).decode('utf8')
|
||||
input.append({"image": img_b64s}) # base64 encoded image
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embeddings"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text], task="retrieval.query")
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class MistralEmbed(Base):
|
||||
_FACTORY_NAME = "Mistral"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user