# Feat: Gemini supports video parsing (#10671)

### What problem does this PR solve?

Gemini now supports video parsing: video files uploaded to the picture parser are routed to a Gemini vision model, and the returned summary is tokenized and indexed as the chunk content.
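
For context, here is a minimal sketch of the call path this PR introduces (illustrative wiring; the names come from the diffs below, and the real entry point in the shipped code is the picture parser's `chunk()`):

```python
# Hypothetical driver, assuming a configured tenant and a Gemini IMAGE2TEXT model.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

def summarize_video(tenant_id: str, video_bytes: bytes, filename: str) -> str:
    # LLMBundle.chat now forwards extra kwargs (video_bytes, filename) to the
    # underlying model; GeminiCV.chat detects video_bytes and delegates to
    # its new _process_video helper.
    cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang="English")
    return cv_mdl.chat(system="", history=[], gen_conf={},
                       video_bytes=video_bytes, filename=filename)
```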


![img_v3_02r8_adbd5adc-d665-4756-9a00-3ae0f12224fg](https://github.com/user-attachments/assets/30d8d296-c336-4b55-9823-803979e705ca)


![img_v3_02r8_ab60c046-1727-4029-ad2e-66097fd3ccbg](https://github.com/user-attachments/assets/441b1487-a970-427e-98b6-6e1e002f2bad)

Closes #10617

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit: 5b2e5dd334 (parent: de46b0d46e)
Author: Yongteng Lei, committed via GitHub
Date: 2025-10-20 16:49:47 +08:00
4 changed files with 108 additions and 38 deletions


```diff
@@ -205,7 +205,7 @@ class LLMBundle(LLM4Tenant):
             return txt
         return txt[last_think_end + len("</think>") :]

     @staticmethod
     def _clean_param(chat_partial, **kwargs):
         func = chat_partial.func
@@ -222,15 +222,15 @@ class LLMBundle(LLM4Tenant):
         if not support_var_args:
             use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
         return use_kwargs

     def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

-        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
+        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
         if self.is_tools and self.mdl.is_tools:
-            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
+            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
         use_kwargs = self._clean_param(chat_partial, **kwargs)

         txt, used_tokens = chat_partial(**use_kwargs)
         txt = self._remove_reasoning_content(txt)
```
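
This change stays safe for models that do not accept the new arguments because `_clean_param` filters the kwargs against the target function's signature before the call. A minimal sketch of that filtering idea (an illustrative reimplementation, not the exact method body):

```python
import inspect

def clean_param(chat_partial, **kwargs):
    # Keep only kwargs the wrapped chat function accepts, unless it takes
    # **kwargs itself; models without video support never see video_bytes.
    func = chat_partial.func
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs  # func accepts **kwargs: pass everything through
    keyword_args = {
        name for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in keyword_args}
```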


```diff
@@ -1345,35 +1345,35 @@
         "llm_name": "gemini-2.5-flash",
         "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
         "max_tokens": 1048576,
-        "model_type": "chat",
+        "model_type": "image2text",
         "is_tools": true
     },
     {
         "llm_name": "gemini-2.5-pro",
         "tags": "LLM,CHAT,IMAGE2TEXT,1024K",
         "max_tokens": 1048576,
-        "model_type": "chat",
+        "model_type": "image2text",
         "is_tools": true
     },
     {
         "llm_name": "gemini-2.5-flash-lite",
         "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
         "max_tokens": 1048576,
-        "model_type": "chat",
+        "model_type": "image2text",
         "is_tools": true
     },
     {
         "llm_name": "gemini-2.0-flash",
         "tags": "LLM,CHAT,1024K",
         "max_tokens": 1048576,
-        "model_type": "chat",
+        "model_type": "image2text",
         "is_tools": true
     },
     {
         "llm_name": "gemini-2.0-flash-lite",
         "tags": "LLM,CHAT,1024K",
         "max_tokens": 1048576,
-        "model_type": "chat",
+        "model_type": "image2text",
         "is_tools": true
     },
     {
```
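
Retagging `model_type` from `chat` to `image2text` on these entries is what allows the video branch below to resolve a Gemini model through the vision model type; presumably the lookup mirrors the call added in the picture parser:

```python
# Illustrative: with the retagged factory entries, a Gemini model can be
# instantiated as a CV model and receive the new video kwargs.
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
```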


```diff
@@ -23,44 +23,62 @@ from PIL import Image
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
 from deepdoc.vision import OCR
-from rag.nlp import tokenize
+from rag.nlp import rag_tokenizer, tokenize
 from rag.utils import clean_markdown_block
-from rag.nlp import rag_tokenizer

 ocr = OCR()

+# Gemini supported MIME types
+VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp"]
+

 def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
-    img = Image.open(io.BytesIO(binary)).convert('RGB')
     doc = {
         "docnm_kwd": filename,
         "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
-        "image": img,
-        "doc_type_kwd": "image"
     }
-    bxs = ocr(np.array(img))
-    txt = "\n".join([t[0] for _, t in bxs if t[0]])
     eng = lang.lower() == "english"
-    callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
-    if (eng and len(txt.split()) > 32) or len(txt) > 32:
-        tokenize(doc, txt, eng)
-        callback(0.8, "OCR results is too long to use CV LLM.")
-        return [doc]
-
-    try:
-        callback(0.4, "Use CV LLM to describe the picture.")
-        cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
-        img_binary = io.BytesIO()
-        img.save(img_binary, format='JPEG')
-        img_binary.seek(0)
-        ans = cv_mdl.describe(img_binary.read())
-        callback(0.8, "CV LLM respond: %s ..." % ans[:32])
-        txt += "\n" + ans
-        tokenize(doc, txt, eng)
-        return [doc]
-    except Exception as e:
-        callback(prog=-1, msg=str(e))
+    if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS):
+        try:
+            doc.update({"doc_type_kwd": "video"})
+            cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
+            ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
+            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
+            ans += "\n" + ans
+            tokenize(doc, ans, eng)
+            return [doc]
+        except Exception as e:
+            callback(prog=-1, msg=str(e))
+    else:
+        img = Image.open(io.BytesIO(binary)).convert("RGB")
+        doc.update(
+            {
+                "image": img,
+                "doc_type_kwd": "image",
+            }
+        )
+        bxs = ocr(np.array(img))
+        txt = "\n".join([t[0] for _, t in bxs if t[0]])
+        callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
+        if (eng and len(txt.split()) > 32) or len(txt) > 32:
+            tokenize(doc, txt, eng)
+            callback(0.8, "OCR results is too long to use CV LLM.")
+            return [doc]
+
+        try:
+            callback(0.4, "Use CV LLM to describe the picture.")
+            cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
+            img_binary = io.BytesIO()
+            img.save(img_binary, format="JPEG")
+            img_binary.seek(0)
+            ans = cv_mdl.describe(img_binary.read())
+            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
+            txt += "\n" + ans
+            tokenize(doc, txt, eng)
+            return [doc]
+        except Exception as e:
+            callback(prog=-1, msg=str(e))

     return []
```
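
A quick sanity check of the extension gate above (`is_video` is a hypothetical helper for illustration; the committed code inlines the `any(...)` expression):

```python
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp"]

def is_video(filename: str) -> bool:
    # Case-insensitive suffix match, the same predicate used in chunk().
    return any(filename.lower().endswith(ext) for ext in VIDEO_EXTS)

assert is_video("demo.MP4")
assert not is_video("scan.png")
```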
```diff
@@ -79,7 +97,7 @@ def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
     try:
         with io.BytesIO() as img_binary:
-            img.save(img_binary, format='JPEG')
+            img.save(img_binary, format="JPEG")
             img_binary.seek(0)
             ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
             txt += "\n" + ans
```


```diff
@@ -16,6 +16,7 @@
 import base64
 import json
 import os
+import logging
 from abc import ABC
 from copy import deepcopy
 from io import BytesIO
@@ -529,6 +530,7 @@ class GeminiCV(Base):
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
+        self.api_key = key
         self.model_name = model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
@@ -571,7 +573,15 @@ class GeminiCV(Base):
         res = self.model.generate_content(input)
         return res.text, total_token_count_from_response(res)

-    def chat(self, system, history, gen_conf, images=[]):
+    def chat(self, system, history, gen_conf, images=[], video_bytes=None, filename=""):
+        if video_bytes:
+            try:
+                summary, summary_num_tokens = self._process_video(video_bytes, filename)
+                return summary, summary_num_tokens
+            except Exception as e:
+                return "**ERROR**: " + str(e), 0
+
         generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
         try:
             response = self.model.generate_content(
```
```diff
@@ -603,6 +613,48 @@ class GeminiCV(Base):
             yield total_token_count_from_response(response)

+    def _process_video(self, video_bytes, filename):
+        from google import genai
+        from google.genai import types
+        import tempfile
+        from pathlib import Path
+
+        video_size_mb = len(video_bytes) / (1024 * 1024)
+        client = genai.Client(api_key=self.api_key)
+        tmp_path = None
+        try:
+            if video_size_mb <= 20:
+                response = client.models.generate_content(
+                    model="models/gemini-2.5-flash",
+                    contents=types.Content(parts=[
+                        types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
+                        types.Part(text="Please summarize the video in proper sentences.")
+                    ])
+                )
+            else:
+                logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
+                video_suffix = Path(filename).suffix or ".mp4"
+                with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
+                    tmp.write(video_bytes)
+                    tmp_path = Path(tmp.name)
+
+                uploaded_file = client.files.upload(file=tmp_path)
+                response = client.models.generate_content(
+                    model="gemini-2.5-flash",
+                    contents=[uploaded_file, "Please summarize this video in proper sentences."]
+                )
+
+            summary = response.text or ""
+            logging.info(f"Video summarized: {summary[:32]}...")
+            return summary, num_tokens_from_string(summary)
+        except Exception as e:
+            logging.error(f"Video processing failed: {e}")
+            raise
+        finally:
+            if tmp_path and tmp_path.exists():
+                tmp_path.unlink()
+
+
 class NvidiaCV(Base):
     _FACTORY_NAME = "NVIDIA"
```