Add support for NVIDIA LLM (#1645)

### What problem does this PR solve?

Add support for NVIDIA LLM.
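
A minimal usage sketch of the new class (the import path and the model name below are assumptions for illustration, not part of this diff; `NvidiaChat` and its `chat()` signature are taken from the code added here):

```python
# Hypothetical import path: assumes NvidiaChat lands in rag/llm/chat_model.py
# alongside the other chat providers touched by this diff.
from rag.llm.chat_model import NvidiaChat

mdl = NvidiaChat(key="<NVIDIA_API_KEY>", model_name="meta/llama3-8b-instruct")
ans, total_tokens = mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Hello!"}],
    gen_conf={"temperature": 0.7, "top_p": 0.9, "max_tokens": 256},
)
print(ans, total_tokens)
```
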
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Author: 黄腾
Committed by GitHub on 2024-07-23 10:43:09 +08:00
parent 95821f6fb6
commit b4a281eca1
8 changed files with 508 additions and 7 deletions


@@ -581,7 +581,6 @@ class MiniMaxChat(Base):
         response = requests.request(
             "POST", url=self.base_url, headers=headers, data=payload
         )
-        print(response, flush=True)
         response = response.json()
         ans = response["choices"][0]["message"]["content"].strip()
         if response["choices"][0]["finish_reason"] == "length":
@@ -902,4 +901,79 @@ class StepFunChat(Base):
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1/chat/completions"):
         if not base_url:
             base_url = "https://api.stepfun.com/v1/chat/completions"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class NvidiaChat(Base):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
+    ):
+        if not base_url:
+            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
+        self.base_url = base_url
+        self.model_name = model_name
+        self.api_key = key
+        self.headers = {
+            "accept": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        payload = {"model": self.model_name, "messages": history, **gen_conf}
+        try:
+            response = requests.post(
+                url=self.base_url, headers=self.headers, json=payload
+            )
+            response = response.json()
+            ans = response["choices"][0]["message"]["content"].strip()
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+            "stream": True,
+            **gen_conf,
+        }
+        try:
+            response = requests.post(
+                url=self.base_url,
+                headers=self.headers,
+                json=payload,
+            )
+            for resp in response.text.split("\n\n"):
+                if "choices" not in resp:
+                    continue
+                resp = json.loads(resp[6:])
+                if "content" in resp["choices"][0]["delta"]:
+                    text = resp["choices"][0]["delta"]["content"]
+                else:
+                    continue
+                ans += text
+                if "usage" in resp:
+                    total_tokens = resp["usage"]["total_tokens"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
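
A sketch of consuming the streaming method (assumes the `mdl` instance from the earlier sketch; per the code above, `chat_streamly` yields the accumulated answer string after each delta and yields the total token count as its final value):

```python
for chunk in mdl.chat_streamly(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Write one sentence about GPUs."}],
    gen_conf={"temperature": 0.7, "max_tokens": 128},
):
    if isinstance(chunk, str):
        print(chunk)          # progressively longer partial answer
    else:
        total_tokens = chunk  # final yield is the int token count
```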