mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor: improve cv model logics (#10414)
1. improve how to get total token count Improve how to get total token count ### Type of change - [x] Refactoring
This commit is contained in:
@ -26,7 +26,7 @@ from openai.lib.azure import AzureOpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
from rag.nlp import is_english
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
from rag.utils import num_tokens_from_string
|
||||
from rag.utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -125,7 +125,7 @@ class Base(ABC):
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
with BytesIO() as buffered:
|
||||
fmt = "JPEG"
|
||||
fmt = "jpeg"
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception:
|
||||
@ -133,10 +133,10 @@ class Base(ABC):
|
||||
buffered.seek(0)
|
||||
buffered.truncate()
|
||||
image.save(buffered, format="PNG")
|
||||
fmt = "PNG"
|
||||
fmt = "png"
|
||||
data = buffered.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
mime = f"image/{fmt.lower()}"
|
||||
mime = f"image/{fmt}"
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
def prompt(self, b64):
|
||||
@ -178,7 +178,7 @@ class GptV4(Base):
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
)
|
||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
b64 = self.image2base64(image)
|
||||
@ -186,7 +186,7 @@ class GptV4(Base):
|
||||
model=self.model_name,
|
||||
messages=self.vision_llm_prompt(b64, prompt),
|
||||
)
|
||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
|
||||
|
||||
|
||||
class AzureGptV4(GptV4):
|
||||
@ -522,11 +522,10 @@ class GeminiCV(Base):
|
||||
)
|
||||
b64 = self.image2base64(image)
|
||||
with BytesIO(base64.b64decode(b64)) as bio:
|
||||
img = open(bio)
|
||||
input = [prompt, img]
|
||||
res = self.model.generate_content(input)
|
||||
img.close()
|
||||
return res.text, res.usage_metadata.total_token_count
|
||||
with open(bio) as img:
|
||||
input = [prompt, img]
|
||||
res = self.model.generate_content(input)
|
||||
return res.text, total_token_count_from_response(res)
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
from PIL.Image import open
|
||||
@ -534,11 +533,10 @@ class GeminiCV(Base):
|
||||
b64 = self.image2base64(image)
|
||||
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
|
||||
with BytesIO(base64.b64decode(b64)) as bio:
|
||||
img = open(bio)
|
||||
input = [vision_prompt, img]
|
||||
res = self.model.generate_content(input)
|
||||
img.close()
|
||||
return res.text, res.usage_metadata.total_token_count
|
||||
with open(bio) as img:
|
||||
input = [vision_prompt, img]
|
||||
res = self.model.generate_content(input)
|
||||
return res.text, total_token_count_from_response(res)
|
||||
|
||||
def chat(self, system, history, gen_conf, images=[]):
|
||||
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
|
||||
@ -547,7 +545,7 @@ class GeminiCV(Base):
|
||||
self._form_history(system, history, images),
|
||||
generation_config=generation_config)
|
||||
ans = response.text
|
||||
return ans, response.usage_metadata.total_token_count
|
||||
return ans, total_token_count_from_response(ans)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
@ -570,10 +568,7 @@ class GeminiCV(Base):
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
if response and hasattr(response, "usage_metadata") and hasattr(response.usage_metadata, "total_token_count"):
|
||||
yield response.usage_metadata.total_token_count
|
||||
else:
|
||||
yield 0
|
||||
yield total_token_count_from_response(response)
|
||||
|
||||
|
||||
class NvidiaCV(Base):
|
||||
|
||||
@ -95,6 +95,12 @@ def total_token_count_from_response(resp):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
|
||||
try:
|
||||
return resp.usage_metadata.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if 'usage' in resp and 'total_tokens' in resp['usage']:
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
|
||||
Reference in New Issue
Block a user