mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add bce-embedding and fastembed (#383)
### What problem does this PR solve? Issue link:#326 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -24,8 +24,8 @@ EmbeddingModel = {
|
||||
"Xinference": XinferenceEmbed,
|
||||
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
||||
"ZHIPU-AI": ZhipuEmbed,
|
||||
"Moonshot": HuEmbedding,
|
||||
"FastEmbed": FastEmbed
|
||||
"FastEmbed": FastEmbed,
|
||||
"QAnything": QAnythingEmbed
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ from abc import ABC
|
||||
from ollama import Client
|
||||
import dashscope
|
||||
from openai import OpenAI
|
||||
from fastembed import TextEmbedding
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import numpy as np
|
||||
@ -28,16 +27,17 @@ import numpy as np
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
try:
|
||||
flag_model = FlagModel(os.path.join(
|
||||
get_project_base_directory(),
|
||||
"rag/res/bge-large-zh-v1.5"),
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
get_project_base_directory(),
|
||||
"rag/res/bge-large-zh-v1.5"),
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
except Exception as e:
|
||||
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -82,8 +82,10 @@ class HuEmbedding(Base):
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
|
||||
if not base_url: base_url="https://api.openai.com/v1"
|
||||
def __init__(self, key, model_name="text-embedding-ada-002",
|
||||
base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
@ -142,7 +144,7 @@ class ZhipuEmbed(Base):
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
res = self.client.embeddings.create(input=txt,
|
||||
model=self.model_name)
|
||||
model=self.model_name)
|
||||
arr.append(res.data[0].embedding)
|
||||
tks_num += res.usage.total_tokens
|
||||
return np.array(arr), tks_num
|
||||
@ -163,14 +165,14 @@ class OllamaEmbed(Base):
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
res = self.client.embeddings(prompt=txt,
|
||||
model=self.model_name)
|
||||
model=self.model_name)
|
||||
arr.append(res["embedding"])
|
||||
tks_num += 128
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings(prompt=text,
|
||||
model=self.model_name)
|
||||
model=self.model_name)
|
||||
return np.array(res["embedding"]), 128
|
||||
|
||||
|
||||
@ -183,10 +185,12 @@ class FastEmbed(Base):
|
||||
threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
from fastembed import TextEmbedding
|
||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
# Using the internal tokenizer to encode the texts and get the total number of tokens
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
encodings = self._model.model.tokenizer.encode_batch(texts)
|
||||
total_tokens = sum(len(e) for e in encodings)
|
||||
|
||||
@ -195,7 +199,8 @@ class FastEmbed(Base):
|
||||
return np.array(embeddings), total_tokens
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
# Using the internal tokenizer to encode the texts and get the total number of tokens
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
encoding = self._model.model.tokenizer.encode(text)
|
||||
embedding = next(self._model.query_embed(text)).tolist()
|
||||
|
||||
@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
|
||||
model=self.model_name)
|
||||
return np.array(res.data[0].embedding), res.usage.total_tokens
|
||||
|
||||
|
||||
class QAnythingEmbed(Base):
|
||||
_client = None
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||
from BCEmbedding import EmbeddingModel as qanthing
|
||||
if not QAnythingEmbed._client:
|
||||
try:
|
||||
print("LOADING BCE...")
|
||||
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
|
||||
get_project_base_directory(),
|
||||
"rag/res/bce-embedding-base_v1"))
|
||||
except Exception as e:
|
||||
QAnythingEmbed._client = qanthing(
|
||||
model_name_or_path=model_name.replace(
|
||||
"maidalun1020", "InfiniFlow"))
|
||||
|
||||
def encode(self, texts: list, batch_size=10):
|
||||
res = []
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
for i in range(0, len(texts), batch_size):
|
||||
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
|
||||
res.extend(embds)
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds = QAnythingEmbed._client.encode([text])
|
||||
return np.array(embds[0]), num_tokens_from_string(text)
|
||||
|
||||
Reference in New Issue
Block a user