add support for NVIDIA llm (#1645)

### What problem does this PR solve?

add support for NVIDIA llm
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-07-23 10:43:09 +08:00
committed by GitHub
parent 95821f6fb6
commit b4a281eca1
8 changed files with 508 additions and 7 deletions

View File

@ -164,3 +164,41 @@ class LocalAIRerank(Base):
def similarity(self, query: str, texts: list):
raise NotImplementedError("The LocalAIRerank has not been implement")
class NvidiaRerank(Base):
def __init__(
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
self.model_name = model_name
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
self.base_url = os.path.join(
base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
)
if self.model_name == "nvidia/rerank-qa-mistral-4b":
self.base_url = os.path.join(base_url, "reranking")
self.model_name = "nv-rerank-qa-mistral-4b:1"
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum(
[num_tokens_from_string(t) for t in texts]
)
data = {
"model": self.model_name,
"query": {"text": query},
"passages": [{"text": text} for text in texts],
"truncate": "END",
"top_n": len(texts),
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return (np.array([d["logit"] for d in res["rankings"]]), token_count)