Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
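The diff below tags each vision-model class with a `_FACTORY_NAME` class attribute (a string, or a list of strings such as `["VLLM", "OpenAI-API-Compatible"]`), so provider classes describe themselves instead of being enumerated in a hand-maintained table. A minimal sketch of how such attributes could drive automatic registration follows; the helper below is illustrative only, not the repository's actual registration code:

```python
# Illustrative sketch: collect classes tagged with _FACTORY_NAME into a
# {factory_name: class} registry, assuming each provider class declares
# the attribute as in the diff below.
import inspect
import sys


def collect_factories(module_name: str) -> dict:
    registry = {}
    for _, cls in inspect.getmembers(sys.modules[module_name], inspect.isclass):
        names = getattr(cls, "_FACTORY_NAME", None)
        if not names:
            continue
        # _FACTORY_NAME may be a single string or a list of strings.
        for name in names if isinstance(names, list) else [names]:
            registry[name] = cls
    return registry
```

With every class self-describing, adding a new provider only requires defining the class; no central registration list has to be edited.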
@@ -57,7 +57,7 @@ class Base(ABC):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
@@ -79,7 +79,7 @@ class Base(ABC):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
@@ -87,8 +87,7 @@ class Base(ABC):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens
@@ -117,13 +116,12 @@ class Base(ABC):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{b64}"
-                        },
+                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                     },
                     {
-                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }
@@ -136,9 +134,7 @@ class Base(ABC):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{b64}"
-                        },
+                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                     },
                     {
                         "type": "text",
@@ -156,14 +152,13 @@ class Base(ABC):
                         "url": f"data:image/jpeg;base64,{b64}",
                     },
                 },
-                {
-                    "type": "text",
-                    "text": text
-                },
+                {"type": "text", "text": text},
             ]


 class GptV4(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
@@ -181,7 +176,7 @@ class GptV4(Base):

         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt
+            messages=prompt,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens

@@ -197,9 +192,11 @@ class GptV4(Base):


 class AzureGptV4(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        api_key = json.loads(key).get('api_key', '')
-        api_version = json.loads(key).get('api_version', '2024-02-01')
+        api_key = json.loads(key).get("api_key", "")
+        api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
         self.lang = lang
@@ -212,10 +209,7 @@ class AzureGptV4(Base):
             if "text" in c:
                 c["type"] = "text"

-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):
@@ -230,8 +224,11 @@ class AzureGptV4(Base):


 class QWenCV(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
         import dashscope
+
         dashscope.api_key = key
         self.model_name = model_name
         self.lang = lang
@@ -247,12 +244,11 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
-                    {
-                        "image": f"file://{path}"
-                    },
+                    {"image": f"file://{path}"},
                     {
-                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }
@@ -270,11 +266,9 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
-                    {
-                        "image": f"file://{path}"
-                    },
+                    {"image": f"file://{path}"},
                     {
                         "text": prompt if prompt else vision_llm_describe_prompt(),
                     },
                 ],
             }
@@ -290,9 +284,10 @@ class QWenCV(Base):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def describe_with_prompt(self, image, prompt=None):
@@ -303,33 +298,36 @@ class QWenCV(Base):
         vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
         response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def chat(self, system, history, gen_conf, image=""):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

         for his in history:
             if his["role"] == "user":
                 his["content"] = self.chat_prompt(his["content"], image)
-        response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                               temperature=gen_conf.get("temperature", 0.3),
-                                               top_p=gen_conf.get("top_p", 0.7))
+        response = MultiModalConversation.call(
+            model=self.model_name,
+            messages=history,
+            temperature=gen_conf.get("temperature", 0.3),
+            top_p=gen_conf.get("top_p", 0.7),
+        )

         ans = ""
         tk_count = 0
         if response.status_code == HTTPStatus.OK:
-            ans = response.output.choices[0]['message']['content']
+            ans = response.output.choices[0]["message"]["content"]
             if isinstance(ans, list):
                 ans = ans[0]["text"] if ans else ""
             tk_count += response.usage.total_tokens
             if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, tk_count

         return "**ERROR**: " + response.message, tk_count
@@ -338,6 +336,7 @@ class QWenCV(Base):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -348,24 +347,25 @@ class QWenCV(Base):
         ans = ""
         tk_count = 0
         try:
-            response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                                   temperature=gen_conf.get("temperature", 0.3),
-                                                   top_p=gen_conf.get("top_p", 0.7),
-                                                   stream=True)
+            response = MultiModalConversation.call(
+                model=self.model_name,
+                messages=history,
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True,
+            )
             for resp in response:
                 if resp.status_code == HTTPStatus.OK:
-                    cnt = resp.output.choices[0]['message']['content']
+                    cnt = resp.output.choices[0]["message"]["content"]
                     if isinstance(cnt, list):
                         cnt = cnt[0]["text"] if ans else ""
                     ans += cnt
                     tk_count = resp.usage.total_tokens
                     if resp.output.choices[0].get("finish_reason", "") == "length":
-                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
-                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@@ -373,6 +373,8 @@ class QWenCV(Base):


 class Zhipu4V(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
@@ -394,10 +396,7 @@ class Zhipu4V(Base):
         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)

-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=vision_prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def chat(self, system, history, gen_conf, image=""):
@@ -412,7 +411,7 @@ class Zhipu4V(Base):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
@@ -434,7 +433,7 @@ class Zhipu4V(Base):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
@@ -442,8 +441,7 @@ class Zhipu4V(Base):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens
@@ -455,6 +453,8 @@ class Zhipu4V(Base):


 class OllamaCV(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name
@@ -466,7 +466,7 @@ class OllamaCV(Base):
             response = self.client.generate(
                 model=self.model_name,
                 prompt=prompt[0]["content"][1]["text"],
-                images=[image]
+                images=[image],
             )
             ans = response["response"].strip()
             return ans, 128
@@ -507,7 +507,7 @@ class OllamaCV(Base):
                 model=self.model_name,
                 messages=history,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )

             ans = response["message"]["content"].strip()
@@ -538,7 +538,7 @@ class OllamaCV(Base):
                 messages=history,
                 stream=True,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )
             for resp in response:
                 if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):


 class LocalAICV(GptV4):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):


 class XinferenceCV(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
     def describe(self, image):
         b64 = self.image2base64(image)

-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=self.prompt(b64)
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):


 class GeminiCV(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
         from google.generativeai import GenerativeModel, client
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):

     def describe(self, image):
         from PIL.Image import open
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt, img]
-        res = self.model.generate_content(
-            input
-        )
+        res = self.model.generate_content(input)
         return res.text, res.usage_metadata.total_token_count

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open
+
         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
         img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):

     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
         try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)))
+            response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

             ans = response.text
             return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):

     def chat_streamly(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -661,9 +668,11 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+            response = self.model.generate_content(
+                history,
+                generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
+                stream=True,
+            )

             for resp in response:
                 if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):


 class OpenRouterCV(GptV4):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(
         self,
         key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):


 class LocalCV(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass

@@ -700,6 +713,8 @@ class LocalCV(Base):


 class NvidiaCV(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(
         self,
         key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
                 "content-type": "application/json",
                 "Authorization": f"Bearer {self.key}",
             },
-            json={
-                "messages": self.prompt(b64)
-            },
+            json={"messages": self.prompt(b64)},
         )
         response = response.json()
         return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
         return [
             {
                 "role": "user",
-                "content": (
-                    prompt if prompt else vision_llm_describe_prompt()
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
             }
         ]

@@ -791,6 +801,8 @@ class NvidiaCV(Base):


 class StepFunCV(GptV4):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):


 class LmStudioCV(GptV4):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):


 class OpenAI_APICV(GptV4):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):


 class TogetherAICV(GptV4):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):


 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
+    _FACTORY_NAME = "01.AI"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.lingyiwanwu.com/v1",
+    ):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
         super().__init__(key, model_name, lang, base_url)


 class SILICONFLOWCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
+    _FACTORY_NAME = "SILICONFLOW"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.siliconflow.cn/v1",
+    ):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
         super().__init__(key, model_name, lang, base_url)


 class HunyuanCV(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None):
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
                 "Contents": [
                     {
                         "Type": "image_url",
-                        "ImageUrl": {
-                            "Url": f"data:image/jpeg;base64,{b64}"
-                        },
+                        "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
                     },
                     {
                         "Type": "text",
-                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }
@@ -910,6 +945,8 @@ class HunyuanCV(Base):


 class AnthropicCV(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url=None):
         import anthropic

@@ -933,38 +970,29 @@ class AnthropicCV(Base):
                         "data": b64,
                     },
                 },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
+                {"type": "text", "text": prompt},
             ],
         }
     ]

     def describe(self, image):
         b64 = self.image2base64(image)
-        prompt = self.prompt(b64,
-                             "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                             )
+        prompt = self.prompt(
+            b64,
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+        )

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def chat(self, system, history, gen_conf):
         if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
         ).to_dict()
         ans = response["content"][0]["text"]
         if response["stop_reason"] == "max_tokens":
-            ans += (
-                "...\nFor the content length reason, it stopped, continue?"
-                if is_english([ans])
-                else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            )
+            ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
         return (
             ans,
             response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
             **gen_conf,
         )
         for res in response:
-            if res.type == 'content_block_delta':
+            if res.type == "content_block_delta":
                 if res.delta.type == "thinking_delta" and res.delta.thinking:
                     if ans.find("<think>") < 0:
                         ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):

         yield total_tokens

+
 class GPUStackCV(GptV4):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -1041,11 +1068,13 @@ class GPUStackCV(GptV4):


 class GoogleCV(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
         import base64

         from google.oauth2 import service_account

         key = json.loads(key)
         access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
         project_id = key.get("google_project_id", "")
@@ -1079,9 +1108,12 @@ class GoogleCV(Base):
         self.client = glm.GenerativeModel(model_name=self.model_name)

     def describe(self, image):
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )

         if "claude" in self.model_name:
             b64 = self.image2base64(image)
             vision_prompt = [
@@ -1096,28 +1128,22 @@ class GoogleCV(Base):
                                "data": b64,
                            },
                        },
-                        {
-                            "type": "text",
-                            "text": prompt
-                        }
+                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=8192,
-                messages=vision_prompt
+                messages=vision_prompt,
            )
            return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
        else:
            import vertexai.generative_models as glm

            b64 = self.image2base64(image)
            # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
            input = [prompt, image_part]
            res = self.client.generate_content(input)
            return res.text, res.usage_metadata.total_token_count
@@ -1137,29 +1163,19 @@ class GoogleCV(Base):
                                "data": b64,
                            },
                        },
-                        {
-                            "type": "text",
-                            "text": prompt if prompt else vision_llm_describe_prompt()
-                        }
+                        {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
                    ],
                }
            ]
-            response = self.client.messages.create(
-                model=self.model_name,
-                max_tokens=8192,
-                messages=vision_prompt
-            )
+            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
            return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
        else:
            import vertexai.generative_models as glm

            b64 = self.image2base64(image)
            vision_prompt = prompt if prompt else vision_llm_describe_prompt()
            # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
            input = [vision_prompt, image_part]
            res = self.client.generate_content(input)
            return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
                                "data": image,
                            },
                        },
-                        {
-                            "type": "text",
-                            "text": his["content"]
-                        }
+                        {"type": "text", "text": his["content"]},
                    ]

-                response = self.client.messages.create(
-                    model=self.model_name,
-                    max_tokens=8192,
-                    messages=history,
-                    temperature=gen_conf.get("temperature", 0.3),
-                    top_p=gen_conf.get("top_p", 0.7)
-                )
+                response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
                return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
            except Exception as e:
                return "**ERROR**: " + str(e), 0
        else:
            import vertexai.generative_models as glm
            from transformers import GenerationConfig
+
            if system:
                history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
            try:
@@ -1210,20 +1218,15 @@ class GoogleCV(Base):
                if his["role"] == "user":
                    his["parts"] = [his["content"]]
                    his.pop("content")

                # Create proper image part for Gemini
                img_bytes = base64.b64decode(image)
-                image_part = glm.Part.from_data(
-                    data=img_bytes,
-                    mime_type="image/jpeg"
-                )
+                image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
                history[-1]["parts"].append(image_part)

-                response = self.client.generate_content(history, generation_config=GenerationConfig(
-                    temperature=gen_conf.get("temperature", 0.3),
-                    top_p=gen_conf.get("top_p", 0.7)))
+                response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

                ans = response.text
                return ans, response.usage_metadata.total_token_count
            except Exception as e:
                return "**ERROR**: " + str(e), 0