mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add support for NVIDIA llm (#1645)
### What problem does this PR solve? add support for NVIDIA llm ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
@ -164,3 +164,41 @@ class LocalAIRerank(Base):
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("The LocalAIRerank has not been implement")
|
||||
|
||||
|
||||
class NvidiaRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
):
|
||||
if not base_url:
|
||||
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
self.model_name = model_name
|
||||
|
||||
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
|
||||
self.base_url = os.path.join(
|
||||
base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
|
||||
)
|
||||
|
||||
if self.model_name == "nvidia/rerank-qa-mistral-4b":
|
||||
self.base_url = os.path.join(base_url, "reranking")
|
||||
self.model_name = "nv-rerank-qa-mistral-4b:1"
|
||||
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum(
|
||||
[num_tokens_from_string(t) for t in texts]
|
||||
)
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": {"text": query},
|
||||
"passages": [{"text": text} for text in texts],
|
||||
"truncate": "END",
|
||||
"top_n": len(texts),
|
||||
}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
|
||||
|
||||
Reference in New Issue
Block a user