add support for NVIDIA llm (#1645)

### What problem does this PR solve?

add support for NVIDIA llm
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-07-23 10:43:09 +08:00
committed by GitHub
parent 95821f6fb6
commit b4a281eca1
8 changed files with 508 additions and 7 deletions

View File

@ -137,7 +137,6 @@ class Base(ABC):
]
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
@ -619,3 +618,65 @@ class LocalCV(Base):
def describe(self, image, max_tokens=1024):
return "", 0
class NvidiaCV(Base):
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://ai.api.nvidia.com/v1/vlm",
):
if not base_url:
base_url = ("https://ai.api.nvidia.com/v1/vlm",)
self.lang = lang
factory, llm_name = model_name.split("/")
if factory != "liuhaotian":
self.base_url = os.path.join(base_url, factory, llm_name)
else:
self.base_url = os.path.join(
base_url, "community", llm_name.replace("-v1.6", "16")
)
self.key = key
def describe(self, image, max_tokens=1024):
b64 = self.image2base64(image)
response = requests.post(
url=self.base_url,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={
"messages": self.prompt(b64),
"max_tokens": max_tokens,
},
)
response = response.json()
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
)
def prompt(self, b64):
return [
{
"role": "user",
"content": (
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
)
+ f' <img src="data:image/jpeg;base64,{b64}"/>',
}
]
def chat_prompt(self, text, b64):
return [
{
"role": "user",
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
}
]