mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Gemini supports video parsing (#10671)
### What problem does this PR solve?

Gemini supports video parsing. Close: #10617

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -23,44 +23,62 @@ from PIL import Image
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.vision import OCR
|
||||
from rag.nlp import tokenize
|
||||
from rag.nlp import rag_tokenizer, tokenize
|
||||
from rag.utils import clean_markdown_block
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
|
||||
ocr = OCR()
|
||||
|
||||
# File extensions treated as video and routed to the CV model (formats Gemini accepts)
|
||||
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp"]
|
||||
|
||||
|
||||
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    """Turn one picture or video file into a single tokenized chunk.

    Files whose name ends with one of ``VIDEO_EXTS`` are handed to the
    tenant's IMAGE2TEXT model through ``chat(video_bytes=...)`` and the reply
    is tokenized. Any other file is opened as an image and OCR'd; when the
    OCR text is short, the IMAGE2TEXT model's description of the picture is
    appended to it before tokenizing.

    Args:
        filename: Name of the file; used for the title tokens and to decide
            between the video and the image path.
        binary: Raw bytes of the file.
        tenant_id: Tenant whose IMAGE2TEXT model is used.
        lang: Language name; ``"english"`` (case-insensitive) switches the
            length check and tokenization to English mode.
        callback: Progress reporter, called as ``callback(progress, message)``
            and as ``callback(prog=-1, msg=...)`` on failure. Optional.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``[doc]`` with the populated chunk dict on success, or ``[]`` when the
        model call failed (the failure is reported through ``callback``).
    """
    if callback is None:
        # The signature allows omitting the callback, but every path below
        # reports progress, so fall back to a no-op reporter.
        def callback(*_args, **_kwargs):
            return None

    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
    }
    eng = lang.lower() == "english"

    if filename.lower().endswith(tuple(VIDEO_EXTS)):
        try:
            doc.update({"doc_type_kwd": "video"})
            cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
            ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
            # There is no OCR text to prepend for videos, so the reply is
            # tokenized as-is (previously it was concatenated with itself,
            # storing every summary twice).
            # NOTE(review): the model wrapper may hand back an "**ERROR**: ..."
            # string instead of raising -- confirm whether that should be
            # rejected here rather than indexed.
            tokenize(doc, ans, eng)
            return [doc]
        except Exception as e:
            callback(prog=-1, msg=str(e))
    else:
        img = Image.open(io.BytesIO(binary)).convert("RGB")
        doc.update(
            {
                "image": img,
                "doc_type_kwd": "image",
            }
        )
        bxs = ocr(np.array(img))
        txt = "\n".join([t[0] for _, t in bxs if t[0]])
        callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
        # Enough OCR text already describes the picture; skip the CV model.
        if (eng and len(txt.split()) > 32) or len(txt) > 32:
            tokenize(doc, txt, eng)
            callback(0.8, "OCR results is too long to use CV LLM.")
            return [doc]

        try:
            callback(0.4, "Use CV LLM to describe the picture.")
            cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
            img_binary = io.BytesIO()
            img.save(img_binary, format="JPEG")
            img_binary.seek(0)
            ans = cv_mdl.describe(img_binary.read())
            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
            txt += "\n" + ans
            tokenize(doc, txt, eng)
            return [doc]
        except Exception as e:
            callback(prog=-1, msg=str(e))

    return []
|
||||
|
||||
@ -79,7 +97,7 @@ def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
|
||||
|
||||
try:
|
||||
with io.BytesIO() as img_binary:
|
||||
img.save(img_binary, format='JPEG')
|
||||
img.save(img_binary, format="JPEG")
|
||||
img_binary.seek(0)
|
||||
ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
|
||||
txt += "\n" + ans
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
@ -529,6 +530,7 @@ class GeminiCV(Base):
|
||||
|
||||
client.configure(api_key=key)
|
||||
_client = client.get_default_generative_client()
|
||||
self.api_key=key
|
||||
self.model_name = model_name
|
||||
self.model = GenerativeModel(model_name=self.model_name)
|
||||
self.model._client = _client
|
||||
@ -571,7 +573,15 @@ class GeminiCV(Base):
|
||||
res = self.model.generate_content(input)
|
||||
return res.text, total_token_count_from_response(res)
|
||||
|
||||
def chat(self, system, history, gen_conf, images=[]):
|
||||
|
||||
def chat(self, system, history, gen_conf, images=[], video_bytes=None, filename=""):
|
||||
if video_bytes:
|
||||
try:
|
||||
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
@ -603,6 +613,48 @@ class GeminiCV(Base):
|
||||
|
||||
yield total_token_count_from_response(response)
|
||||
|
||||
def _process_video(self, video_bytes, filename):
    """Ask Gemini to summarize a video and return ``(summary, token_count)``.

    Videos of at most 20 MB are sent inline with the request. Larger videos
    are written to a temporary file, uploaded through the Files API, and
    referenced from the request once the upload has finished processing.

    Args:
        video_bytes: Raw bytes of the video.
        filename: Original file name. Only its suffix is used: to pick the
            MIME type of an inline upload and as the temp-file suffix.

    Returns:
        A tuple of the summary text (``""`` when the response has no text)
        and ``num_tokens_from_string(summary)``.

    Raises:
        TimeoutError: The uploaded file was still being processed after the
            wait limit.
        RuntimeError: The service reported that processing the upload failed.
        Exception: Any other error from the SDK is logged and re-raised.
    """
    import tempfile
    import time
    from pathlib import Path

    from google import genai
    from google.genai import types

    # MIME types Gemini documents for video input, keyed by file suffix.
    mime_by_suffix = {
        ".mp4": "video/mp4",
        ".mov": "video/mov",
        ".avi": "video/avi",
        ".flv": "video/x-flv",
        ".mpeg": "video/mpeg",
        ".mpg": "video/mpg",
        ".webm": "video/webm",
        ".wmv": "video/wmv",
        ".3gp": "video/3gpp",
        ".3gpp": "video/3gpp",
    }
    suffix = Path(filename or "").suffix
    # Unknown or missing suffixes keep the previous behavior (video/mp4).
    mime_type = mime_by_suffix.get(suffix.lower(), "video/mp4")

    video_size_mb = len(video_bytes) / (1024 * 1024)
    client = genai.Client(api_key=self.api_key)

    tmp_path = None
    try:
        if video_size_mb <= 20:
            response = client.models.generate_content(
                model="models/gemini-2.5-flash",
                contents=types.Content(parts=[
                    types.Part(inline_data=types.Blob(data=video_bytes, mime_type=mime_type)),
                    types.Part(text="Please summarize the video in proper sentences.")
                ])
            )
        else:
            logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
            video_suffix = suffix or ".mp4"
            # delete=False: the file must outlive the ``with`` so it can be
            # uploaded by path; the ``finally`` below removes it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
                tmp.write(video_bytes)
                tmp_path = Path(tmp.name)
            uploaded_file = client.files.upload(file=tmp_path)

            # An uploaded video cannot be referenced until the service has
            # finished processing it, so poll (bounded) until it leaves the
            # PROCESSING state.
            deadline = time.monotonic() + 300
            state = getattr(uploaded_file.state, "name", uploaded_file.state)
            while state == "PROCESSING":
                if time.monotonic() > deadline:
                    raise TimeoutError("Timed out waiting for the uploaded video to be processed.")
                time.sleep(2)
                uploaded_file = client.files.get(name=uploaded_file.name)
                state = getattr(uploaded_file.state, "name", uploaded_file.state)
            if state == "FAILED":
                raise RuntimeError("The uploaded video could not be processed.")

            response = client.models.generate_content(
                model="gemini-2.5-flash",
                contents=[uploaded_file, "Please summarize this video in proper sentences."]
            )

        summary = response.text or ""
        logging.info(f"Video summarized: {summary[:32]}...")
        return summary, num_tokens_from_string(summary)
    except Exception as e:
        logging.error(f"Video processing failed: {e}")
        raise
    finally:
        # Always clean up the temp copy used for the Files API upload.
        if tmp_path and tmp_path.exists():
            tmp_path.unlink()
|
||||
|
||||
|
||||
class NvidiaCV(Base):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
Reference in New Issue
Block a user