Add support for NVIDIA LLM (#1645)

### What problem does this PR solve?

Add support for NVIDIA LLM
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-07-23 10:43:09 +08:00
committed by GitHub
parent 95821f6fb6
commit b4a281eca1
8 changed files with 508 additions and 7 deletions

View File

@ -462,3 +462,41 @@ class GeminiEmbed(Base):
title="Embedding of single string")
token_count = num_tokens_from_string(text)
return np.array(result['embedding']),token_count
class NvidiaEmbed(Base):
    """Embedding client for NVIDIA's hosted (NIM) embeddings HTTP API.

    Some models are served from model-specific retrieval endpoints rather
    than the generic ``/v1/embeddings`` route, so ``__init__`` rewrites
    ``base_url`` (and, for ``embed-qa-4``, the served model id) accordingly.
    """

    def __init__(
        self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
    ):
        # Callers may pass an empty/None base_url; fall back to the default.
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
        self.api_key = key
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name
        # These models are only reachable via dedicated retrieval endpoints;
        # embed-qa-4 is additionally registered under a different model id.
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        elif model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list, batch_size=None):
        """Embed *texts*, splitting the work into API requests of at most
        ``batch_size`` inputs.

        Returns a tuple ``(embeddings, token_count)`` where ``embeddings``
        is an ndarray with one row per input text and ``token_count`` is
        the total tokens consumed across all requests.
        Raises ValueError when the API responds without a ``data`` field.
        """
        # Fix: the original ignored batch_size entirely and sent every text
        # in a single request, which fails once the list exceeds the API's
        # per-request input limit.
        if not batch_size:
            batch_size = 16
        embeddings = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            payload = {
                "input": texts[start : start + batch_size],
                "input_type": "query",
                "model": self.model_name,
                "encoding_format": "float",
                "truncate": "END",
            }
            res = requests.post(self.base_url, headers=self.headers, json=payload).json()
            if "data" not in res:
                # Surface API-side errors instead of an opaque KeyError.
                raise ValueError(f"NVIDIA embedding request failed: {res}")
            embeddings.extend(d["embedding"] for d in res["data"])
            token_count += res["usage"]["total_tokens"]
        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed a single query string.

        Returns ``(embedding, token_count)`` with a 1-D embedding vector.
        """
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt