diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index bec92ba6c..03b913e97 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import base64 import json import os +import tempfile import logging from abc import ABC from copy import deepcopy from io import BytesIO +from pathlib import Path from urllib.parse import urljoin import requests from openai import OpenAI @@ -171,6 +174,7 @@ class GptV4(Base): def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs): if not base_url: base_url = "https://api.openai.com/v1" + self.api_key = key self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang @@ -224,6 +228,61 @@ class QWenCV(GptV4): base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs) + def chat(self, system, history, gen_conf, images=[], video_bytes=None, filename=""): + if video_bytes: + try: + summary, summary_num_tokens = self._process_video(video_bytes, filename) + return summary, summary_num_tokens + except Exception as e: + return "**ERROR**: " + str(e), 0 + + return "**ERROR**: Method chat not supported yet.", 0 + + def _process_video(self, video_bytes, filename): + from dashscope import MultiModalConversation + + video_suffix = Path(filename).suffix or ".mp4" + with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp: + tmp.write(video_bytes) + tmp_path = tmp.name + + video_path = f"file://{tmp_path}" + messages = [ + { + "role": "user", + "content": [ + { + "video": video_path, + "fps": 2, + }, + { + "text": "Please summarize this video in proper sentences.", + }, + ], + } + ] + + def call_api(): + response = MultiModalConversation.call( + api_key=self.api_key, + model=self.model_name, + messages=messages, + ) + summary = response["output"]["choices"][0]["message"].content[0]["text"] + return summary, num_tokens_from_string(summary) + + try: + return call_api() + except Exception as e1: + import dashscope + + dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1" + try: + return call_api() + except Exception as e2: + raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}") + + class HunyuanCV(GptV4): _FACTORY_NAME = "Tencent Hunyuan" @@ -616,8 +675,6 @@ class GeminiCV(Base): def _process_video(self, video_bytes, filename): from google import genai from google.genai import types - import tempfile - from pathlib import Path video_size_mb = len(video_bytes) / (1024 * 1024) client = genai.Client(api_key=self.api_key)