diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 0a1559319..7e763641a 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -26,7 +26,7 @@ from openai.lib.azure import AzureOpenAI from zhipuai import ZhipuAI from rag.nlp import is_english from rag.prompts.generator import vision_llm_describe_prompt -from rag.utils import num_tokens_from_string +from rag.utils import num_tokens_from_string, total_token_count_from_response class Base(ABC): @@ -125,7 +125,7 @@ class Base(ABC): b64 = base64.b64encode(data).decode("utf-8") return f"data:{mime};base64,{b64}" with BytesIO() as buffered: - fmt = "JPEG" + fmt = "jpeg" try: image.save(buffered, format="JPEG") except Exception: @@ -133,10 +133,10 @@ class Base(ABC): buffered.seek(0) buffered.truncate() image.save(buffered, format="PNG") - fmt = "PNG" + fmt = "png" data = buffered.getvalue() b64 = base64.b64encode(data).decode("utf-8") - mime = f"image/{fmt.lower()}" + mime = f"image/{fmt}" return f"data:{mime};base64,{b64}" def prompt(self, b64): @@ -178,7 +178,7 @@ class GptV4(Base): model=self.model_name, messages=self.prompt(b64), ) - return res.choices[0].message.content.strip(), res.usage.total_tokens + return res.choices[0].message.content.strip(), total_token_count_from_response(res) def describe_with_prompt(self, image, prompt=None): b64 = self.image2base64(image) @@ -186,7 +186,7 @@ class GptV4(Base): model=self.model_name, messages=self.vision_llm_prompt(b64, prompt), ) - return res.choices[0].message.content.strip(), res.usage.total_tokens + return res.choices[0].message.content.strip(),total_token_count_from_response(res) class AzureGptV4(GptV4): @@ -522,11 +522,10 @@ class GeminiCV(Base): ) b64 = self.image2base64(image) with BytesIO(base64.b64decode(b64)) as bio: - img = open(bio) - input = [prompt, img] - res = self.model.generate_content(input) - img.close() - return res.text, res.usage_metadata.total_token_count + with open(bio) as img: + input = [prompt, img] + res = self.model.generate_content(input) + return res.text, total_token_count_from_response(res) def describe_with_prompt(self, image, prompt=None): from PIL.Image import open @@ -534,11 +533,10 @@ class GeminiCV(Base): b64 = self.image2base64(image) vision_prompt = prompt if prompt else vision_llm_describe_prompt() with BytesIO(base64.b64decode(b64)) as bio: - img = open(bio) - input = [vision_prompt, img] - res = self.model.generate_content(input) - img.close() - return res.text, res.usage_metadata.total_token_count + with open(bio) as img: + input = [vision_prompt, img] + res = self.model.generate_content(input) + return res.text, total_token_count_from_response(res) def chat(self, system, history, gen_conf, images=[]): generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)) @@ -547,7 +545,7 @@ class GeminiCV(Base): self._form_history(system, history, images), generation_config=generation_config) ans = response.text - return ans, response.usage_metadata.total_token_count + return ans, total_token_count_from_response(ans) except Exception as e: return "**ERROR**: " + str(e), 0 @@ -570,10 +568,7 @@ class GeminiCV(Base): except Exception as e: yield ans + "\n**ERROR**: " + str(e) - if response and hasattr(response, "usage_metadata") and hasattr(response.usage_metadata, "total_token_count"): - yield response.usage_metadata.total_token_count - else: - yield 0 + yield total_token_count_from_response(response) class NvidiaCV(Base): diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index 22445da92..798b5bf60 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -95,6 +95,12 @@ def total_token_count_from_response(resp): except Exception: pass + if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"): + try: + return resp.usage_metadata.total_tokens + except Exception: + pass + if 'usage' in resp and 'total_tokens' in resp['usage']: try: return resp["usage"]["total_tokens"]