Mirror of https://github.com/infiniflow/ragflow.git
add rerank model (#969)
### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
 from typing import Optional
 
+import requests
 from huggingface_hub import snapshot_download
 from zhipuai import ZhipuAI
 import os
@@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
 import torch
 import numpy as np
 
-from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
+from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 
-try:
-    flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
-                           query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                           use_fp16=torch.cuda.is_available())
-except Exception as e:
-    model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
-                                  local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
-                                  local_dir_use_symlinks=False)
-    flag_model = FlagModel(model_dir,
-                           query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                           use_fp16=torch.cuda.is_available())
-
 
 class Base(ABC):
     def __init__(self, key, model_name):
@@ -54,7 +44,9 @@ class Base(ABC):
 
 
 class DefaultEmbedding(Base):
-    def __init__(self, *args, **kwargs):
+    _model = None
+
+    def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
 
@@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
         ^_-
 
         """
-        self.model = flag_model
+        if not DefaultEmbedding._model:
+            try:
+                self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                        use_fp16=torch.cuda.is_available())
+            except Exception as e:
+                model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
+                                              local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                              local_dir_use_symlinks=False)
+                self._model = FlagModel(model_dir,
+                                        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                        use_fp16=torch.cuda.is_available())
 
     def encode(self, texts: list, batch_size=32):
         texts = [truncate(t, 2048) for t in texts]
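The hunk above moves the BGE bootstrap out of module import and into `DefaultEmbedding.__init__`, guarded by the class attribute `_model`, so the weights are only fetched when the default embedder is actually used. One caveat: the guard reads `DefaultEmbedding._model` but the assignment targets `self._model`, so the class attribute stays `None` and each new instance still loads its own copy. A minimal sketch of the fully shared version of this pattern (the lock is my addition, not part of the commit):

```python
import threading

class LazyModelHolder:
    """Sketch: class-level lazy cache shared across instances."""
    _model = None
    _lock = threading.Lock()  # assumed addition: serializes first-time loading

    def __init__(self, loader):
        # Double-checked locking: skip the lock on the hot path,
        # re-check inside it so only one thread runs the loader.
        if LazyModelHolder._model is None:
            with LazyModelHolder._lock:
                if LazyModelHolder._model is None:
                    LazyModelHolder._model = loader()
        self._model = LazyModelHolder._model
```

Assigning to the class attribute, not only `self._model`, is what lets later instantiations reuse the already-loaded model.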
@@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
             token_count += num_tokens_from_string(t)
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
+            res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
         return np.array(res), token_count
 
     def encode_queries(self, text: str):
         token_count = num_tokens_from_string(text)
-        return self.model.encode_queries([text]).tolist()[0], token_count
+        return self._model.encode_queries([text]).tolist()[0], token_count
 
 
 class OpenAIEmbed(Base):
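For context, a hypothetical call against the rewritten class (names mirror the diff; the `key` argument is unused by the local model, and the 1024-dim shape assumes the default `bge-large-zh-v1.5` checkpoint):

```python
embedder = DefaultEmbedding(key="", model_name="BAAI/bge-large-zh-v1.5")

# encode() truncates each text to 2048 tokens, batches in chunks of 32,
# and returns the stacked vectors plus the summed token count.
vecs, n_tokens = embedder.encode(["hello world", "检索增强生成"])
print(vecs.shape, n_tokens)   # e.g. (2, 1024)

# encode_queries() applies the retrieval instruction before embedding.
qvec, q_tokens = embedder.encode_queries("what is RAG?")
print(len(qvec), q_tokens)    # e.g. 1024
```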
@@ -189,16 +192,19 @@ class OllamaEmbed(Base):
 
 
 class FastEmbed(Base):
+    _model = None
+
     def __init__(
-        self,
-        key: Optional[str] = None,
-        model_name: str = "BAAI/bge-small-en-v1.5",
-        cache_dir: Optional[str] = None,
-        threads: Optional[int] = None,
-        **kwargs,
+            self,
+            key: Optional[str] = None,
+            model_name: str = "BAAI/bge-small-en-v1.5",
+            cache_dir: Optional[str] = None,
+            threads: Optional[int] = None,
+            **kwargs,
     ):
         from fastembed import TextEmbedding
-        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        if not FastEmbed._model:
+            self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
 
     def encode(self, texts: list, batch_size=32):
         # Using the internal tokenizer to encode the texts and get the total
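`FastEmbed` gets the same class-level `_model` guard (with the same caveat: the assignment goes to `self._model`, so the class attribute is never populated). For reference, the underlying fastembed API it wraps looks roughly like this; a sketch, and exact defaults can vary across fastembed versions:

```python
from fastembed import TextEmbedding

# Downloads the ONNX model on first use and caches it locally.
model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")

# embed() yields one numpy vector per input text.
vectors = list(model.embed(["sample passage", "another passage"]))
print(vectors[0].shape)  # (384,) for bge-small-en-v1.5
```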
@@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
     def encode_queries(self, text):
         embds = YoudaoEmbed._client.encode([text])
         return np.array(embds[0]), num_tokens_from_string(text)
+
+
+class JinaEmbed(Base):
+    def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
+                 base_url="https://api.jina.ai/v1/embeddings"):
+
+        self.base_url = base_url
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}"
+        }
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=None):
+        texts = [truncate(t, 8196) for t in texts]
+        data = {
+            "model": self.model_name,
+            "input": texts,
+            "encoding_type": "float"
+        }
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
+
+    def encode_queries(self, text):
+        embds, cnt = self.encode([text])
+        return np.array(embds[0]), cnt
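`JinaEmbed` talks to the hosted Jina embeddings endpoint over plain HTTP rather than loading a local model. A standalone sketch of the same request (placeholder API key; the response follows the OpenAI-style schema the class parses, and must be decoded with `.json()` before indexing):

```python
import requests

headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer <JINA_API_KEY>",  # placeholder, not a real key
}
payload = {
    "model": "jina-embeddings-v2-base-zh",
    "input": ["深度学习", "deep learning"],
    "encoding_type": "float",
}
resp = requests.post("https://api.jina.ai/v1/embeddings",
                     headers=headers, json=payload).json()

# OpenAI-style body: data[i]["embedding"] plus usage["total_tokens"].
embeddings = [d["embedding"] for d in resp["data"]]
print(len(embeddings), resp["usage"]["total_tokens"])
```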