mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 00:25:06 +08:00
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -13,24 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from huggingface_hub import snapshot_download
|
||||
import os
|
||||
from abc import ABC
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from yarl import URL
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
import json
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
@ -57,6 +57,7 @@ class Base(ABC):
|
||||
|
||||
|
||||
class DefaultRerank(Base):
|
||||
_FACTORY_NAME = "BAAI"
|
||||
_model = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
@ -75,17 +76,13 @@ class DefaultRerank(Base):
|
||||
if not settings.LIGHTEN and not DefaultRerank._model:
|
||||
import torch
|
||||
from FlagEmbedding import FlagReranker
|
||||
|
||||
with DefaultRerank._model_lock:
|
||||
if not DefaultRerank._model:
|
||||
try:
|
||||
DefaultRerank._model = FlagReranker(
|
||||
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
use_fp16=torch.cuda.is_available())
|
||||
DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id=model_name,
|
||||
local_dir=os.path.join(get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
|
||||
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
||||
self._model = DefaultRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
@ -94,6 +91,7 @@ class DefaultRerank(Base):
|
||||
def torch_empty_cache(self):
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
print(f"Error emptying cache: {e}")
|
||||
@ -112,7 +110,7 @@ class DefaultRerank(Base):
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# call subclass implemented batch processing calculation
|
||||
batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
|
||||
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
|
||||
res.extend(batch_scores)
|
||||
i += current_batch
|
||||
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
|
||||
@ -152,23 +150,16 @@ class DefaultRerank(Base):
|
||||
|
||||
|
||||
class JinaRerank(Base):
|
||||
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
|
||||
base_url="https://api.jina.ai/v1/rerank"):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
|
||||
self.base_url = "https://api.jina.ai/v1/rerank"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts)
|
||||
}
|
||||
data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
@ -180,22 +171,20 @@ class JinaRerank(Base):
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "Youdao"
|
||||
_model = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||
if not settings.LIGHTEN and not YoudaoRerank._model:
|
||||
from BCEmbedding import RerankerModel
|
||||
|
||||
with YoudaoRerank._model_lock:
|
||||
if not YoudaoRerank._model:
|
||||
try:
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
||||
get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
||||
except Exception:
|
||||
YoudaoRerank._model = RerankerModel(
|
||||
model_name_or_path=model_name.replace(
|
||||
"maidalun1020", "InfiniFlow"))
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||
|
||||
self._model = YoudaoRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):
|
||||
|
||||
|
||||
class XInferenceRerank(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key="x", model_name="", base_url=""):
|
||||
if base_url.find("/v1") == -1:
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
@ -219,10 +210,7 @@ class XInferenceRerank(Base):
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"accept": "application/json"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "accept": "application/json"}
|
||||
if key and key != "x":
|
||||
self.headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
@ -233,13 +221,7 @@ class XInferenceRerank(Base):
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"return_documents": "true",
|
||||
"return_len": "true",
|
||||
"documents": texts
|
||||
}
|
||||
data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
@ -251,15 +233,14 @@ class XInferenceRerank(Base):
|
||||
|
||||
|
||||
class LocalAIRerank(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
@ -296,16 +277,15 @@ class LocalAIRerank(Base):
|
||||
|
||||
|
||||
class NvidiaRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
self.model_name = model_name
|
||||
|
||||
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
|
||||
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
|
||||
)
|
||||
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
|
||||
|
||||
if self.model_name == "nvidia/rerank-qa-mistral-4b":
|
||||
self.base_url = urljoin(base_url, "reranking")
|
||||
@ -318,9 +298,7 @@ class NvidiaRerank(Base):
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum(
|
||||
[num_tokens_from_string(t) for t in texts]
|
||||
)
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": {"text": query},
|
||||
@ -339,6 +317,8 @@ class NvidiaRerank(Base):
|
||||
|
||||
|
||||
class LmStudioRerank(Base):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
pass
|
||||
|
||||
@ -347,15 +327,14 @@ class LmStudioRerank(Base):
|
||||
|
||||
|
||||
class OpenAI_APIRerank(Base):
|
||||
_FACTORY_NAME = "OpenAI-API-Compatible"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):
|
||||
|
||||
|
||||
class CoHereRerank(Base):
|
||||
_FACTORY_NAME = ["Cohere", "VLLM"]
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from cohere import Client
|
||||
|
||||
@ -399,9 +380,7 @@ class CoHereRerank(Base):
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum(
|
||||
[num_tokens_from_string(t) for t in texts]
|
||||
)
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
res = self.client.rerank(
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
@ -419,6 +398,8 @@ class CoHereRerank(Base):
|
||||
|
||||
|
||||
class TogetherAIRerank(Base):
|
||||
_FACTORY_NAME = "TogetherAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
pass
|
||||
|
||||
@ -427,9 +408,9 @@ class TogetherAIRerank(Base):
|
||||
|
||||
|
||||
class SILICONFLOWRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
|
||||
):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1/rerank"
|
||||
self.model_name = model_name
|
||||
@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
|
||||
"max_chunks_per_doc": 1024,
|
||||
"overlap_tokens": 80,
|
||||
}
|
||||
response = requests.post(
|
||||
self.base_url, json=payload, headers=self.headers
|
||||
).json()
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in response["results"]:
|
||||
@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):
|
||||
|
||||
|
||||
class BaiduYiyanRerank(Base):
|
||||
_FACTORY_NAME = "BaiduYiyan"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from qianfan.resources import Reranker
|
||||
|
||||
@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):
|
||||
|
||||
|
||||
class VoyageRerank(Base):
|
||||
_FACTORY_NAME = "Voyage AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import voyageai
|
||||
|
||||
@ -502,9 +485,7 @@ class VoyageRerank(Base):
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
if not texts:
|
||||
return rank, 0
|
||||
res = self.client.rerank(
|
||||
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
||||
)
|
||||
res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
|
||||
try:
|
||||
for r in res.results:
|
||||
rank[r.index] = r.relevance_score
|
||||
@ -514,22 +495,20 @@ class VoyageRerank(Base):
|
||||
|
||||
|
||||
class QWenRerank(Base):
|
||||
def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
|
||||
import dashscope
|
||||
|
||||
self.api_key = key
|
||||
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
resp = dashscope.TextReRank.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
documents=texts,
|
||||
top_n=len(texts),
|
||||
return_documents=False
|
||||
)
|
||||
|
||||
import dashscope
|
||||
|
||||
resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
try:
|
||||
@ -543,6 +522,8 @@ class QWenRerank(Base):
|
||||
|
||||
|
||||
class HuggingfaceRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
@staticmethod
|
||||
def post(query: str, texts: list, url="127.0.0.1"):
|
||||
exc = None
|
||||
@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
|
||||
batch_size = 8
|
||||
for i in range(0, len(texts), batch_size):
|
||||
try:
|
||||
res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
|
||||
json={"query": query, "texts": texts[i: i + batch_size],
|
||||
"raw_scores": False, "truncate": True})
|
||||
res = requests.post(
|
||||
f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
|
||||
)
|
||||
|
||||
for o in res.json():
|
||||
scores[o["index"] + i] = o["score"]
|
||||
@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):
|
||||
|
||||
|
||||
class GPUStackRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url
|
||||
):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
|
||||
@ -600,9 +581,7 @@ class GPUStackRerank(Base):
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.base_url, json=payload, headers=self.headers
|
||||
)
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
@ -623,11 +602,12 @@ class GPUStackRerank(Base):
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise ValueError(
|
||||
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
|
||||
class NovitaRerank(JinaRerank):
|
||||
_FACTORY_NAME = "NovitaAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.novita.ai/v3/openai/rerank"
|
||||
@ -635,7 +615,9 @@ class NovitaRerank(JinaRerank):
|
||||
|
||||
|
||||
class GiteeRerank(JinaRerank):
    """Rerank client for Gitee AI, reusing the JinaRerank request logic."""

    # Provider name; presumably the key used by the automatic registration
    # this commit introduces -- confirm against the registry code.
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        """Configure the client for the Gitee AI rerank endpoint.

        Args:
            key: API key forwarded to ``JinaRerank.__init__``.
            model_name: Name of the rerank model to request.
            base_url: Full rerank endpoint URL; an empty or ``None`` value
                falls back to the Gitee AI default.
        """
        if not base_url:
            base_url = "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, base_url)
        # Set explicitly so requests target the Gitee endpoint regardless of
        # what the parent constructor stores in ``base_url``.
        self.base_url = base_url
|
||||
|
||||
Reference in New Issue
Block a user