mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add Gemini 3 Pro preview (#11361)
### What problem does this PR solve? Add Gemini 3 Pro preview. Change `GenerativeModel` to `genai`. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -1429,6 +1429,13 @@
|
|||||||
"status": "1",
|
"status": "1",
|
||||||
"rank": "980",
|
"rank": "980",
|
||||||
"llm": [
|
"llm": [
|
||||||
|
{
|
||||||
|
"llm_name": "gemini-3-pro-preview",
|
||||||
|
"tags": "LLM,CHAT,1M,IMAGE2TEXT",
|
||||||
|
"max_tokens": 1048576,
|
||||||
|
"model_type": "image2text",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"llm_name": "gemini-2.5-flash",
|
"llm_name": "gemini-2.5-flash",
|
||||||
"tags": "LLM,CHAT,1024K,IMAGE2TEXT",
|
"tags": "LLM,CHAT,1024K,IMAGE2TEXT",
|
||||||
@ -5474,4 +5481,4 @@
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,24 +14,27 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import re
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from openai.lib.azure import AzureOpenAI
|
from openai.lib.azure import AzureOpenAI
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
|
|
||||||
|
from common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||||
from rag.nlp import is_english
|
from rag.nlp import is_english
|
||||||
from rag.prompts.generator import vision_llm_describe_prompt
|
from rag.prompts.generator import vision_llm_describe_prompt
|
||||||
from common.token_utils import num_tokens_from_string, total_token_count_from_response
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -70,12 +73,7 @@ class Base(ABC):
|
|||||||
|
|
||||||
pmpt = [{"type": "text", "text": text}]
|
pmpt = [{"type": "text", "text": text}]
|
||||||
for img in images:
|
for img in images:
|
||||||
pmpt.append({
|
pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return pmpt
|
return pmpt
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
@ -128,7 +126,7 @@ class Base(ABC):
|
|||||||
try:
|
try:
|
||||||
image.save(buffered, format="JPEG")
|
image.save(buffered, format="JPEG")
|
||||||
except Exception:
|
except Exception:
|
||||||
# reset buffer before saving PNG
|
# reset buffer before saving PNG
|
||||||
buffered.seek(0)
|
buffered.seek(0)
|
||||||
buffered.truncate()
|
buffered.truncate()
|
||||||
image.save(buffered, format="PNG")
|
image.save(buffered, format="PNG")
|
||||||
@ -158,7 +156,7 @@ class Base(ABC):
|
|||||||
try:
|
try:
|
||||||
image.save(buffered, format="JPEG")
|
image.save(buffered, format="JPEG")
|
||||||
except Exception:
|
except Exception:
|
||||||
# reset buffer before saving PNG
|
# reset buffer before saving PNG
|
||||||
buffered.seek(0)
|
buffered.seek(0)
|
||||||
buffered.truncate()
|
buffered.truncate()
|
||||||
image.save(buffered, format="PNG")
|
image.save(buffered, format="PNG")
|
||||||
@ -176,18 +174,13 @@ class Base(ABC):
|
|||||||
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||||
if self.lang.lower() == "chinese"
|
if self.lang.lower() == "chinese"
|
||||||
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||||
b64
|
b64,
|
||||||
)
|
),
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
def vision_llm_prompt(self, b64, prompt=None):
|
def vision_llm_prompt(self, b64, prompt=None):
|
||||||
return [
|
return [{"role": "user", "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)}]
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class GptV4(Base):
|
class GptV4(Base):
|
||||||
@ -208,7 +201,7 @@ class GptV4(Base):
|
|||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=self.prompt(b64),
|
messages=self.prompt(b64),
|
||||||
extra_body=self.extra_body,
|
extra_body=self.extra_body,
|
||||||
unused = None,
|
unused=None,
|
||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||||
|
|
||||||
@ -219,7 +212,7 @@ class GptV4(Base):
|
|||||||
messages=self.vision_llm_prompt(b64, prompt),
|
messages=self.vision_llm_prompt(b64, prompt),
|
||||||
extra_body=self.extra_body,
|
extra_body=self.extra_body,
|
||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
|
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||||
|
|
||||||
|
|
||||||
class AzureGptV4(GptV4):
|
class AzureGptV4(GptV4):
|
||||||
@ -324,14 +317,12 @@ class Zhipu4V(GptV4):
|
|||||||
self.lang = lang
|
self.lang = lang
|
||||||
Base.__init__(self, **kwargs)
|
Base.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _clean_conf(self, gen_conf):
|
def _clean_conf(self, gen_conf):
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
gen_conf = self._clean_conf_plealty(gen_conf)
|
gen_conf = self._clean_conf_plealty(gen_conf)
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
|
|
||||||
def _clean_conf_plealty(self, gen_conf):
|
def _clean_conf_plealty(self, gen_conf):
|
||||||
if "presence_penalty" in gen_conf:
|
if "presence_penalty" in gen_conf:
|
||||||
del gen_conf["presence_penalty"]
|
del gen_conf["presence_penalty"]
|
||||||
@ -339,24 +330,17 @@ class Zhipu4V(GptV4):
|
|||||||
del gen_conf["frequency_penalty"]
|
del gen_conf["frequency_penalty"]
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
|
|
||||||
def _request(self, msg, stream, gen_conf={}):
|
def _request(self, msg, stream, gen_conf={}):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url,
|
self.base_url,
|
||||||
json={
|
json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf},
|
||||||
"model": self.model_name,
|
headers={
|
||||||
"messages": msg,
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"stream": stream,
|
"Content-Type": "application/json",
|
||||||
**gen_conf
|
|
||||||
},
|
},
|
||||||
headers= {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
|
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
history.insert(0, {"role": "system", "content": system})
|
history.insert(0, {"role": "system", "content": system})
|
||||||
@ -369,10 +353,9 @@ class Zhipu4V(GptV4):
|
|||||||
|
|
||||||
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
||||||
return cleaned, total_token_count_from_response(response)
|
return cleaned, total_token_count_from_response(response)
|
||||||
|
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
|
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
|
||||||
from rag.nlp import is_chinese
|
from rag.nlp import is_chinese
|
||||||
|
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
@ -402,44 +385,24 @@ class Zhipu4V(GptV4):
|
|||||||
|
|
||||||
yield tk_count
|
yield tk_count
|
||||||
|
|
||||||
|
|
||||||
def describe(self, image):
|
def describe(self, image):
|
||||||
return self.describe_with_prompt(image)
|
return self.describe_with_prompt(image)
|
||||||
|
|
||||||
|
|
||||||
def describe_with_prompt(self, image, prompt=None):
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = "Describe this image."
|
prompt = "Describe this image."
|
||||||
|
|
||||||
# Chat messages
|
# Chat messages
|
||||||
messages = [
|
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": b64}}, {"type": "text", "text": prompt}]}]
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": b64 }
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
resp = self.client.chat.completions.create(
|
resp = self.client.chat.completions.create(model=self.model_name, messages=messages, stream=False)
|
||||||
model=self.model_name,
|
|
||||||
messages=messages,
|
|
||||||
stream=False
|
|
||||||
)
|
|
||||||
|
|
||||||
content = resp.choices[0].message.content.strip()
|
content = resp.choices[0].message.content.strip()
|
||||||
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
||||||
|
|
||||||
return cleaned, num_tokens_from_string(cleaned)
|
return cleaned, num_tokens_from_string(cleaned)
|
||||||
|
|
||||||
|
|
||||||
class StepFunCV(GptV4):
|
class StepFunCV(GptV4):
|
||||||
_FACTORY_NAME = "StepFun"
|
_FACTORY_NAME = "StepFun"
|
||||||
@ -452,6 +415,7 @@ class StepFunCV(GptV4):
|
|||||||
self.lang = lang
|
self.lang = lang
|
||||||
Base.__init__(self, **kwargs)
|
Base.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class VolcEngineCV(GptV4):
|
class VolcEngineCV(GptV4):
|
||||||
_FACTORY_NAME = "VolcEngine"
|
_FACTORY_NAME = "VolcEngine"
|
||||||
|
|
||||||
@ -464,6 +428,7 @@ class VolcEngineCV(GptV4):
|
|||||||
self.lang = lang
|
self.lang = lang
|
||||||
Base.__init__(self, **kwargs)
|
Base.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LmStudioCV(GptV4):
|
class LmStudioCV(GptV4):
|
||||||
_FACTORY_NAME = "LM-Studio"
|
_FACTORY_NAME = "LM-Studio"
|
||||||
|
|
||||||
@ -502,13 +467,7 @@ class TogetherAICV(GptV4):
|
|||||||
class YiCV(GptV4):
|
class YiCV(GptV4):
|
||||||
_FACTORY_NAME = "01.AI"
|
_FACTORY_NAME = "01.AI"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
|
||||||
self,
|
|
||||||
key,
|
|
||||||
model_name,
|
|
||||||
lang="Chinese",
|
|
||||||
base_url="https://api.lingyiwanwu.com/v1", **kwargs
|
|
||||||
):
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.lingyiwanwu.com/v1"
|
base_url = "https://api.lingyiwanwu.com/v1"
|
||||||
super().__init__(key, model_name, lang, base_url, **kwargs)
|
super().__init__(key, model_name, lang, base_url, **kwargs)
|
||||||
@ -517,13 +476,7 @@ class YiCV(GptV4):
|
|||||||
class SILICONFLOWCV(GptV4):
|
class SILICONFLOWCV(GptV4):
|
||||||
_FACTORY_NAME = "SILICONFLOW"
|
_FACTORY_NAME = "SILICONFLOW"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
|
||||||
self,
|
|
||||||
key,
|
|
||||||
model_name,
|
|
||||||
lang="Chinese",
|
|
||||||
base_url="https://api.siliconflow.cn/v1", **kwargs
|
|
||||||
):
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.siliconflow.cn/v1"
|
base_url = "https://api.siliconflow.cn/v1"
|
||||||
super().__init__(key, model_name, lang, base_url, **kwargs)
|
super().__init__(key, model_name, lang, base_url, **kwargs)
|
||||||
@ -532,13 +485,7 @@ class SILICONFLOWCV(GptV4):
|
|||||||
class OpenRouterCV(GptV4):
|
class OpenRouterCV(GptV4):
|
||||||
_FACTORY_NAME = "OpenRouter"
|
_FACTORY_NAME = "OpenRouter"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
|
||||||
self,
|
|
||||||
key,
|
|
||||||
model_name,
|
|
||||||
lang="Chinese",
|
|
||||||
base_url="https://openrouter.ai/api/v1", **kwargs
|
|
||||||
):
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://openrouter.ai/api/v1"
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
api_key = json.loads(key).get("api_key", "")
|
api_key = json.loads(key).get("api_key", "")
|
||||||
@ -549,6 +496,7 @@ class OpenRouterCV(GptV4):
|
|||||||
provider_order = json.loads(key).get("provider_order", "")
|
provider_order = json.loads(key).get("provider_order", "")
|
||||||
self.extra_body = {}
|
self.extra_body = {}
|
||||||
if provider_order:
|
if provider_order:
|
||||||
|
|
||||||
def _to_order_list(x):
|
def _to_order_list(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
return []
|
return []
|
||||||
@ -557,6 +505,7 @@ class OpenRouterCV(GptV4):
|
|||||||
if isinstance(x, (list, tuple)):
|
if isinstance(x, (list, tuple)):
|
||||||
return [str(s).strip() for s in x if str(s).strip()]
|
return [str(s).strip() for s in x if str(s).strip()]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
provider_cfg = {}
|
provider_cfg = {}
|
||||||
provider_order = _to_order_list(provider_order)
|
provider_order = _to_order_list(provider_order)
|
||||||
provider_cfg["order"] = provider_order
|
provider_cfg["order"] = provider_order
|
||||||
@ -616,18 +565,18 @@ class OllamaCV(Base):
|
|||||||
|
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
|
|
||||||
self.client = Client(host=kwargs["base_url"])
|
self.client = Client(host=kwargs["base_url"])
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
||||||
Base.__init__(self, **kwargs)
|
Base.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _clean_img(self, img):
|
def _clean_img(self, img):
|
||||||
if not isinstance(img, str):
|
if not isinstance(img, str):
|
||||||
return img
|
return img
|
||||||
|
|
||||||
#remove the header like "data/*;base64,"
|
# remove the header like "data/*;base64,"
|
||||||
if img.startswith("data:") and ";base64," in img:
|
if img.startswith("data:") and ";base64," in img:
|
||||||
img = img.split(";base64,")[1]
|
img = img.split(";base64,")[1]
|
||||||
return img
|
return img
|
||||||
@ -687,12 +636,7 @@ class OllamaCV(Base):
|
|||||||
|
|
||||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
response = self.client.chat(
|
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||||
model=self.model_name,
|
|
||||||
messages=self._form_history(system, history, images),
|
|
||||||
options=self._clean_conf(gen_conf),
|
|
||||||
keep_alive=self.keep_alive
|
|
||||||
)
|
|
||||||
|
|
||||||
ans = response["message"]["content"].strip()
|
ans = response["message"]["content"].strip()
|
||||||
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||||
@ -702,13 +646,7 @@ class OllamaCV(Base):
|
|||||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
ans = ""
|
ans = ""
|
||||||
try:
|
try:
|
||||||
response = self.client.chat(
|
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||||
model=self.model_name,
|
|
||||||
messages=self._form_history(system, history, images),
|
|
||||||
stream=True,
|
|
||||||
options=self._clean_conf(gen_conf),
|
|
||||||
keep_alive=self.keep_alive
|
|
||||||
)
|
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if resp["done"]:
|
if resp["done"]:
|
||||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||||
@ -723,29 +661,80 @@ class GeminiCV(Base):
|
|||||||
_FACTORY_NAME = "Gemini"
|
_FACTORY_NAME = "Gemini"
|
||||||
|
|
||||||
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
||||||
from google.generativeai import GenerativeModel, client
|
from google import genai
|
||||||
|
|
||||||
client.configure(api_key=key)
|
self.api_key = key
|
||||||
_client = client.get_default_generative_client()
|
|
||||||
self.api_key=key
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.model = GenerativeModel(model_name=self.model_name)
|
self.client = genai.Client(api_key=key)
|
||||||
self.model._client = _client
|
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
Base.__init__(self, **kwargs)
|
Base.__init__(self, **kwargs)
|
||||||
|
logging.info(f"[GeminiCV] Initialized with model={self.model_name} lang={self.lang}")
|
||||||
|
|
||||||
|
def _image_to_part(self, image):
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
if isinstance(image, str) and image.startswith("data:") and ";base64," in image:
|
||||||
|
header, b64data = image.split(",", 1)
|
||||||
|
mime = header.split(":", 1)[1].split(";", 1)[0]
|
||||||
|
data = base64.b64decode(b64data)
|
||||||
|
else:
|
||||||
|
data_url = self.image2base64(image)
|
||||||
|
header, b64data = data_url.split(",", 1)
|
||||||
|
mime = header.split(":", 1)[1].split(";", 1)[0]
|
||||||
|
data = base64.b64decode(b64data)
|
||||||
|
|
||||||
|
return types.Part(
|
||||||
|
inline_data=types.Blob(
|
||||||
|
mime_type=mime,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _form_history(self, system, history, images=None):
|
def _form_history(self, system, history, images=None):
|
||||||
hist = []
|
from google.genai import types
|
||||||
if system:
|
|
||||||
hist.append({"role": "user", "parts": [system, history[0]["content"]]})
|
contents = []
|
||||||
|
images = images or []
|
||||||
|
system_len = len(system) if isinstance(system, str) else 0
|
||||||
|
history_len = len(history) if history else 0
|
||||||
|
images_len = len(images)
|
||||||
|
logging.info(f"[GeminiCV] _form_history called: system_len={system_len} history_len={history_len} images_len={images_len}")
|
||||||
|
|
||||||
|
image_parts = []
|
||||||
for img in images:
|
for img in images:
|
||||||
hist[0]["parts"].append(("data:image/jpeg;base64," + img) if img[:4]!="data" else img)
|
try:
|
||||||
for h in history[1:]:
|
image_parts.append(self._image_to_part(img))
|
||||||
hist.append({"role": "user" if h["role"]=="user" else "model", "parts": [h["content"]]})
|
except Exception:
|
||||||
return hist
|
continue
|
||||||
|
|
||||||
|
remaining_history = history or []
|
||||||
|
if system or remaining_history:
|
||||||
|
parts = []
|
||||||
|
if system:
|
||||||
|
parts.append(types.Part(text=system))
|
||||||
|
if remaining_history:
|
||||||
|
first = remaining_history[0]
|
||||||
|
parts.append(types.Part(text=first.get("content", "")))
|
||||||
|
remaining_history = remaining_history[1:]
|
||||||
|
parts.extend(image_parts)
|
||||||
|
contents.append(types.Content(role="user", parts=parts))
|
||||||
|
elif image_parts:
|
||||||
|
contents.append(types.Content(role="user", parts=image_parts))
|
||||||
|
|
||||||
|
role_map = {"user": "user", "assistant": "model", "system": "user"}
|
||||||
|
for h in remaining_history:
|
||||||
|
role = role_map.get(h.get("role"), "user")
|
||||||
|
contents.append(
|
||||||
|
types.Content(
|
||||||
|
role=role,
|
||||||
|
parts=[types.Part(text=h.get("content", ""))],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return contents
|
||||||
|
|
||||||
def describe(self, image):
|
def describe(self, image):
|
||||||
from PIL.Image import open
|
from google.genai import types
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||||
@ -753,74 +742,104 @@ class GeminiCV(Base):
|
|||||||
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
|
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
|
||||||
)
|
)
|
||||||
|
|
||||||
if image is bytes:
|
contents = [
|
||||||
with BytesIO(image) as bio:
|
types.Content(
|
||||||
with open(bio) as img:
|
role="user",
|
||||||
input = [prompt, img]
|
parts=[
|
||||||
res = self.model.generate_content(input)
|
types.Part(text=prompt),
|
||||||
return res.text, total_token_count_from_response(res)
|
self._image_to_part(image),
|
||||||
else:
|
],
|
||||||
b64 = self.image2base64_rawvalue(image)
|
)
|
||||||
with BytesIO(base64.b64decode(b64)) as bio:
|
]
|
||||||
with open(bio) as img:
|
|
||||||
input = [prompt, img]
|
res = self.client.models.generate_content(
|
||||||
res = self.model.generate_content(input)
|
model=self.model_name,
|
||||||
return res.text, total_token_count_from_response(res)
|
contents=contents,
|
||||||
|
)
|
||||||
|
return res.text, total_token_count_from_response(res)
|
||||||
|
|
||||||
def describe_with_prompt(self, image, prompt=None):
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
from PIL.Image import open
|
from google.genai import types
|
||||||
|
|
||||||
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
|
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
|
||||||
|
|
||||||
if image is bytes:
|
contents = [
|
||||||
with BytesIO(image) as bio:
|
types.Content(
|
||||||
with open(bio) as img:
|
role="user",
|
||||||
input = [vision_prompt, img]
|
parts=[
|
||||||
res = self.model.generate_content(input)
|
types.Part(text=vision_prompt),
|
||||||
return res.text, total_token_count_from_response(res)
|
self._image_to_part(image),
|
||||||
else:
|
],
|
||||||
b64 = self.image2base64_rawvalue(image)
|
)
|
||||||
with BytesIO(base64.b64decode(b64)) as bio:
|
]
|
||||||
with open(bio) as img:
|
|
||||||
input = [vision_prompt, img]
|
|
||||||
res = self.model.generate_content(input)
|
|
||||||
return res.text, total_token_count_from_response(res)
|
|
||||||
|
|
||||||
|
res = self.client.models.generate_content(
|
||||||
|
model=self.model_name,
|
||||||
|
contents=contents,
|
||||||
|
)
|
||||||
|
return res.text, total_token_count_from_response(res)
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
|
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
|
||||||
if video_bytes:
|
if video_bytes:
|
||||||
try:
|
try:
|
||||||
|
size = len(video_bytes) if video_bytes else 0
|
||||||
|
logging.info(f"[GeminiCV] chat called with video: filename={filename} size={size}")
|
||||||
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
||||||
return summary, summary_num_tokens
|
return summary, summary_num_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logging.info(f"[GeminiCV] chat video error: {e}")
|
||||||
return "**ERROR**: " + str(e), 0
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
|
from google.genai import types
|
||||||
|
|
||||||
|
history_len = len(history) if history else 0
|
||||||
|
images_len = len(images) if images else 0
|
||||||
|
logging.info(f"[GeminiCV] chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
|
||||||
|
|
||||||
|
generation_config = types.GenerateContentConfig(
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = self.model.generate_content(
|
response = self.client.models.generate_content(
|
||||||
self._form_history(system, history, images),
|
model=self.model_name,
|
||||||
generation_config=generation_config)
|
contents=self._form_history(system, history, images),
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
ans = response.text
|
ans = response.text
|
||||||
return ans, total_token_count_from_response(ans)
|
logging.info("[GeminiCV] chat completed")
|
||||||
|
return ans, total_token_count_from_response(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logging.warning(f"[GeminiCV] chat error: {e}")
|
||||||
return "**ERROR**: " + str(e), 0
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
ans = ""
|
ans = ""
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
|
from google.genai import types
|
||||||
response = self.model.generate_content(
|
|
||||||
self._form_history(system, history, images),
|
generation_config = types.GenerateContentConfig(
|
||||||
generation_config=generation_config,
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
stream=True,
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
)
|
||||||
|
history_len = len(history) if history else 0
|
||||||
|
images_len = len(images) if images else 0
|
||||||
|
logging.info(f"[GeminiCV] chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
|
||||||
|
|
||||||
|
response_stream = self.client.models.generate_content_stream(
|
||||||
|
model=self.model_name,
|
||||||
|
contents=self._form_history(system, history, images),
|
||||||
|
config=generation_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
for resp in response:
|
for chunk in response_stream:
|
||||||
if not resp.text:
|
if chunk.text:
|
||||||
continue
|
ans += chunk.text
|
||||||
ans = resp.text
|
yield chunk.text
|
||||||
yield ans
|
logging.info("[GeminiCV] chat_streamly completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logging.warning(f"[GeminiCV] chat_streamly error: {e}")
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
yield total_token_count_from_response(response)
|
yield total_token_count_from_response(response)
|
||||||
@ -830,17 +849,15 @@ class GeminiCV(Base):
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
video_size_mb = len(video_bytes) / (1024 * 1024)
|
video_size_mb = len(video_bytes) / (1024 * 1024)
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = self.client if hasattr(self, "client") else genai.Client(api_key=self.api_key)
|
||||||
|
logging.info(f"[GeminiCV] _process_video called: filename={filename} size_mb={video_size_mb:.2f}")
|
||||||
|
|
||||||
tmp_path = None
|
tmp_path = None
|
||||||
try:
|
try:
|
||||||
if video_size_mb <= 20:
|
if video_size_mb <= 20:
|
||||||
response = client.models.generate_content(
|
response = client.models.generate_content(
|
||||||
model="models/gemini-2.5-flash",
|
model="models/gemini-2.5-flash",
|
||||||
contents=types.Content(parts=[
|
contents=types.Content(parts=[types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")), types.Part(text="Please summarize the video in proper sentences.")]),
|
||||||
types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
|
|
||||||
types.Part(text="Please summarize the video in proper sentences.")
|
|
||||||
])
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
|
logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
|
||||||
@ -850,16 +867,13 @@ class GeminiCV(Base):
|
|||||||
tmp_path = Path(tmp.name)
|
tmp_path = Path(tmp.name)
|
||||||
uploaded_file = client.files.upload(file=tmp_path)
|
uploaded_file = client.files.upload(file=tmp_path)
|
||||||
|
|
||||||
response = client.models.generate_content(
|
response = client.models.generate_content(model="gemini-2.5-flash", contents=[uploaded_file, "Please summarize this video in proper sentences."])
|
||||||
model="gemini-2.5-flash",
|
|
||||||
contents=[uploaded_file, "Please summarize this video in proper sentences."]
|
|
||||||
)
|
|
||||||
|
|
||||||
summary = response.text or ""
|
summary = response.text or ""
|
||||||
logging.info(f"Video summarized: {summary[:32]}...")
|
logging.info(f"[GeminiCV] Video summarized: {summary[:32]}...")
|
||||||
return summary, num_tokens_from_string(summary)
|
return summary, num_tokens_from_string(summary)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Video processing failed: {e}")
|
logging.warning(f"[GeminiCV] Video processing failed: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if tmp_path and tmp_path.exists():
|
if tmp_path and tmp_path.exists():
|
||||||
@ -869,13 +883,7 @@ class GeminiCV(Base):
|
|||||||
class NvidiaCV(Base):
|
class NvidiaCV(Base):
|
||||||
_FACTORY_NAME = "NVIDIA"
|
_FACTORY_NAME = "NVIDIA"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs):
|
||||||
self,
|
|
||||||
key,
|
|
||||||
model_name,
|
|
||||||
lang="Chinese",
|
|
||||||
base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
|
|
||||||
):
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = ("https://ai.api.nvidia.com/v1/vlm",)
|
base_url = ("https://ai.api.nvidia.com/v1/vlm",)
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
@ -920,9 +928,7 @@ class NvidiaCV(Base):
|
|||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
"Authorization": f"Bearer {self.key}",
|
"Authorization": f"Bearer {self.key}",
|
||||||
},
|
},
|
||||||
json={
|
json={"messages": msg, **gen_conf},
|
||||||
"messages": msg, **gen_conf
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@ -930,18 +936,12 @@ class NvidiaCV(Base):
|
|||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
response = self._request(vision_prompt)
|
response = self._request(vision_prompt)
|
||||||
return (
|
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
|
||||||
response["choices"][0]["message"]["content"].strip(),
|
|
||||||
total_token_count_from_response(response)
|
|
||||||
)
|
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
response = self._request(self._form_history(system, history, images), gen_conf)
|
response = self._request(self._form_history(system, history, images), gen_conf)
|
||||||
return (
|
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
|
||||||
response["choices"][0]["message"]["content"].strip(),
|
|
||||||
total_token_count_from_response(response)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "**ERROR**: " + str(e), 0
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
@ -950,7 +950,7 @@ class NvidiaCV(Base):
|
|||||||
try:
|
try:
|
||||||
response = self._request(self._form_history(system, history, images), gen_conf)
|
response = self._request(self._form_history(system, history, images), gen_conf)
|
||||||
cnt = response["choices"][0]["message"]["content"]
|
cnt = response["choices"][0]["message"]["content"]
|
||||||
total_tokens += total_token_count_from_response(response)
|
total_tokens += total_token_count_from_response(response)
|
||||||
for resp in cnt:
|
for resp in cnt:
|
||||||
yield resp
|
yield resp
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -978,14 +978,15 @@ class AnthropicCV(Base):
|
|||||||
return text
|
return text
|
||||||
pmpt = [{"type": "text", "text": text}]
|
pmpt = [{"type": "text", "text": text}]
|
||||||
for img in images:
|
for img in images:
|
||||||
pmpt.append({
|
pmpt.append(
|
||||||
"type": "image",
|
{
|
||||||
"source": {
|
"type": "image",
|
||||||
"type": "base64",
|
"source": {
|
||||||
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
|
"type": "base64",
|
||||||
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img)
|
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
|
||||||
},
|
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img),
|
||||||
}
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return pmpt
|
return pmpt
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user