add rerank model (#969)

### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-05-29 16:50:02 +08:00
committed by GitHub
parent e1f0644deb
commit 614defec21
17 changed files with 437 additions and 64 deletions

View File

@ -16,18 +16,19 @@
from .embedding_model import *
from .chat_model import *
from .cv_model import *
from .rerank_model import *
EmbeddingModel = {
"Ollama": OllamaEmbed,
"OpenAI": OpenAIEmbed,
"Xinference": XinferenceEmbed,
"Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
"Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed,
"DeepSeek": DefaultEmbedding,
"BaiChuan": BaiChuanEmbed
"BaiChuan": BaiChuanEmbed,
"BAAI": DefaultEmbedding
}
@ -52,3 +53,9 @@ ChatModel = {
"BaiChuan": BaiChuanChat
}
RerankModel = {
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
}

View File

@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Optional
import requests
from huggingface_hub import snapshot_download
from zhipuai import ZhipuAI
import os
@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
import torch
import numpy as np
from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
try:
flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
except Exception as e:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
local_dir_use_symlinks=False)
flag_model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
class Base(ABC):
def __init__(self, key, model_name):
@ -54,7 +44,9 @@ class Base(ABC):
class DefaultEmbedding(Base):
def __init__(self, *args, **kwargs):
_model = None
def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
^_-
"""
self.model = flag_model
if not DefaultEmbedding._model:
try:
self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
except Exception as e:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
local_dir_use_symlinks=False)
self._model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
def encode(self, texts: list, batch_size=32):
texts = [truncate(t, 2048) for t in texts]
@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
token_count += num_tokens_from_string(t)
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
return np.array(res), token_count
def encode_queries(self, text: str):
token_count = num_tokens_from_string(text)
return self.model.encode_queries([text]).tolist()[0], token_count
return self._model.encode_queries([text]).tolist()[0], token_count
class OpenAIEmbed(Base):
@ -189,16 +192,19 @@ class OllamaEmbed(Base):
class FastEmbed(Base):
_model = None
def __init__(
self,
key: Optional[str] = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs,
self,
key: Optional[str] = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs,
):
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
if not FastEmbed._model:
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
def encode(self, texts: list, batch_size=32):
# Using the internal tokenizer to encode the texts and get the total
@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
def encode_queries(self, text):
embds = YoudaoEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)
class JinaEmbed(Base):
def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = "https://api.jina.ai/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.model_name = model_name
def encode(self, texts: list, batch_size=None):
texts = [truncate(t, 8196) for t in texts]
data = {
"model": self.model_name,
"input": texts,
'encoding_type': 'float'
}
res = requests.post(self.base_url, headers=self.headers, json=data)
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt

113
rag/llm/rerank_model.py Normal file
View File

@ -0,0 +1,113 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import requests
import torch
from FlagEmbedding import FlagReranker
from huggingface_hub import snapshot_download
import os
from abc import ABC
import numpy as np
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
class Base(ABC):
def __init__(self, key, model_name):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("Please implement encode method!")
class DefaultRerank(Base):
_model = None
def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not DefaultRerank._model:
try:
self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
use_fp16=torch.cuda.is_available())
except Exception as e:
self._model = snapshot_download(repo_id=model_name,
local_dir=os.path.join(get_home_cache_dir(),
re.sub(r"^[a-zA-Z]+/", "", model_name)),
local_dir_use_symlinks=False)
self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
use_fp16=torch.cuda.is_available())
def similarity(self, query: str, texts: list):
pairs = [(query,truncate(t, 2048)) for t in texts]
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 32
res = []
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
res.extend(scores)
return np.array(res), token_count
class JinaRerank(Base):
def __init__(self, key, model_name="jina-reranker-v1-base-en",
base_url="https://api.jina.ai/v1/rerank"):
self.base_url = "https://api.jina.ai/v1/rerank"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.model_name = model_name
def similarity(self, query: str, texts: list):
texts = [truncate(t, 8196) for t in texts]
data = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts)
}
res = requests.post(self.base_url, headers=self.headers, json=data)
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
class YoudaoRerank(DefaultRerank):
_model = None
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
from BCEmbedding import RerankerModel
if not YoudaoRerank._model:
try:
print("LOADING BCE...")
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
get_home_cache_dir(),
re.sub(r"^[a-zA-Z]+/", "", model_name)))
except Exception as e:
YoudaoRerank._model = RerankerModel(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))