Refactor: improve cv model logic (#10414)

1. Improve how the total token count is retrieved from model responses.

### Type of change
- [x] Refactoring
Author: Stephen Hu
Date: 2025-10-09 09:47:36 +08:00
Committed by: GitHub
Parent: dba9158f9a
Commit: 4585edc20e

2 changed files with 22 additions and 21 deletions


```diff
@@ -26,7 +26,7 @@ from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
-from rag.utils import num_tokens_from_string
+from rag.utils import num_tokens_from_string, total_token_count_from_response


 class Base(ABC):
@@ -125,7 +125,7 @@ class Base(ABC):
             b64 = base64.b64encode(data).decode("utf-8")
             return f"data:{mime};base64,{b64}"
         with BytesIO() as buffered:
-            fmt = "JPEG"
+            fmt = "jpeg"
             try:
                 image.save(buffered, format="JPEG")
             except Exception:
@@ -133,10 +133,10 @@
                 buffered.seek(0)
                 buffered.truncate()
                 image.save(buffered, format="PNG")
-                fmt = "PNG"
+                fmt = "png"
             data = buffered.getvalue()
         b64 = base64.b64encode(data).decode("utf-8")
-        mime = f"image/{fmt.lower()}"
+        mime = f"image/{fmt}"
         return f"data:{mime};base64,{b64}"

     def prompt(self, b64):
```
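The `image2base64` changes above keep the format tag lowercase from the start (`"jpeg"`/`"png"`), so the MIME type can be interpolated directly instead of calling `.lower()` at the end. A condensed, runnable sketch of the resulting JPEG-save-with-PNG-fallback logic (the function name is illustrative, not from the repo):

```python
import base64
from io import BytesIO

from PIL import Image


def image_to_data_url(image: Image.Image) -> str:
    # Condensed sketch of the image2base64 logic in this diff: try JPEG
    # first, fall back to PNG (e.g. for RGBA images), and keep the format
    # tag lowercase so it can be used in the MIME type as-is.
    with BytesIO() as buffered:
        fmt = "jpeg"
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            buffered.seek(0)
            buffered.truncate()
            image.save(buffered, format="PNG")
            fmt = "png"
        data = buffered.getvalue()
    b64 = base64.b64encode(data).decode("utf-8")
    return f"data:image/{fmt};base64,{b64}"


# RGBA cannot be written as JPEG, so this exercises the PNG fallback.
print(image_to_data_url(Image.new("RGBA", (2, 2)))[:30])  # data:image/png;base64,...
```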
```diff
@@ -178,7 +178,7 @@ class GptV4(Base):
             model=self.model_name,
             messages=self.prompt(b64),
         )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
+        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
@@ -186,7 +186,7 @@
             model=self.model_name,
             messages=self.vision_llm_prompt(b64, prompt),
         )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
+        return res.choices[0].message.content.strip(), total_token_count_from_response(res)


 class AzureGptV4(GptV4):
```
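Both `GptV4` return paths now go through `total_token_count_from_response(res)` instead of dereferencing `res.usage.total_tokens`, which raises as soon as a provider returns no usage block. A small illustration of that failure mode with hypothetical stand-in objects (not the real SDK types):

```python
from types import SimpleNamespace

# Hypothetical stand-ins for provider responses; real SDK objects differ.
with_usage = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
without_usage = SimpleNamespace(usage=None)

print(with_usage.usage.total_tokens)  # 42
try:
    print(without_usage.usage.total_tokens)
except AttributeError as e:
    # Direct attribute access breaks as soon as usage is missing;
    # the shared helper returns 0 instead.
    print(f"AttributeError: {e}")
```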
```diff
@@ -522,11 +522,10 @@ class GeminiCV(Base):
         )
         b64 = self.image2base64(image)
         with BytesIO(base64.b64decode(b64)) as bio:
-            img = open(bio)
-            input = [prompt, img]
-            res = self.model.generate_content(input)
-            img.close()
-        return res.text, res.usage_metadata.total_token_count
+            with open(bio) as img:
+                input = [prompt, img]
+                res = self.model.generate_content(input)
+        return res.text, total_token_count_from_response(res)

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open
@@ -534,11 +533,10 @@
         b64 = self.image2base64(image)
         vision_prompt = prompt if prompt else vision_llm_describe_prompt()
         with BytesIO(base64.b64decode(b64)) as bio:
-            img = open(bio)
-            input = [vision_prompt, img]
-            res = self.model.generate_content(input)
-            img.close()
-        return res.text, res.usage_metadata.total_token_count
+            with open(bio) as img:
+                input = [vision_prompt, img]
+                res = self.model.generate_content(input)
+        return res.text, total_token_count_from_response(res)

     def chat(self, system, history, gen_conf, images=[]):
         generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
@@ -547,7 +545,7 @@
                 self._form_history(system, history, images),
                 generation_config=generation_config)
             ans = response.text
-            return ans, response.usage_metadata.total_token_count
+            return ans, total_token_count_from_response(response)
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -570,10 +568,7 @@
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

-        if response and hasattr(response, "usage_metadata") and hasattr(response.usage_metadata, "total_token_count"):
-            yield response.usage_metadata.total_token_count
-        else:
-            yield 0
+        yield total_token_count_from_response(response)


 class NvidiaCV(Base):
```
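Besides adopting the shared helper, `GeminiCV` now opens the PIL image in a `with` block, so the handle is released even when `generate_content` raises; the old code only closed it on the success path. A minimal sketch of the difference, assuming only Pillow:

```python
from io import BytesIO

from PIL import Image


def make_buf() -> BytesIO:
    buf = BytesIO()
    Image.new("RGB", (4, 4)).save(buf, format="PNG")
    buf.seek(0)
    return buf


# Before: an exception between open() and close() leaks the handle.
img = Image.open(make_buf())
img.close()

# After: the context manager closes the image on success and on error alike.
with Image.open(make_buf()) as img:
    pass  # e.g. res = self.model.generate_content([prompt, img])
```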


```diff
@@ -95,6 +95,12 @@ def total_token_count_from_response(resp):
         except Exception:
             pass

+    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
+        try:
+            return resp.usage_metadata.total_tokens
+        except Exception:
+            pass
+
     if 'usage' in resp and 'total_tokens' in resp['usage']:
         try:
             return resp["usage"]["total_tokens"]
```
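The added branch covers Gemini-style responses that expose `usage_metadata.total_tokens`. Since only fragments of `total_token_count_from_response` are visible in this diff, the sketch below reconstructs a plausible full fallback chain; every branch outside the visible hunk is an assumption:

```python
def total_token_count_from_response(resp):
    """Best-effort total token count; returns 0 when no usage data exists.

    Reconstructed sketch: only the usage_metadata.total_tokens branch and
    the dict-style 'usage' branch appear in this diff; the rest is assumed.
    """
    # OpenAI-style objects: resp.usage.total_tokens (assumed branch).
    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
    # Gemini-style objects: resp.usage_metadata.total_token_count (assumed branch).
    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_token_count"):
        try:
            return resp.usage_metadata.total_token_count
        except Exception:
            pass
    # Gemini-style objects exposing total_tokens: the branch this commit adds.
    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
        try:
            return resp.usage_metadata.total_tokens
        except Exception:
            pass
    # Plain dict responses: resp["usage"]["total_tokens"].
    try:
        if "usage" in resp and "total_tokens" in resp["usage"]:
            return resp["usage"]["total_tokens"]
    except TypeError:
        pass
    return 0


print(total_token_count_from_response({"usage": {"total_tokens": 123}}))  # 123
print(total_token_count_from_response(object()))                          # 0
```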