Refactor: improve cv model logic (#10414)
Improve how the total token count is obtained from model responses.

### Type of change

- [x] Refactoring
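The diff touches two files: the CV model classes (presumably `rag/llm/cv_model.py`) and the shared helper `total_token_count_from_response` (presumably `rag/utils/__init__.py`). The core idea is to stop reading provider-specific usage attributes at every call site, since each SDK reports usage in a different shape, and funnel everything through one tolerant helper. A minimal sketch of that idea, using hypothetical stand-in objects rather than real SDK responses:

```python
from types import SimpleNamespace

def total_tokens_sketch(resp):
    # Try each usage shape a provider might expose; fall back to 0.
    probes = (
        lambda r: r.usage.total_tokens,                # OpenAI-style objects
        lambda r: r.usage_metadata.total_token_count,  # Gemini-style objects
        lambda r: r["usage"]["total_tokens"],          # plain-dict payloads
    )
    for probe in probes:
        try:
            return probe(resp)
        except Exception:
            continue
    return 0

openai_like = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
gemini_like = SimpleNamespace(usage_metadata=SimpleNamespace(total_token_count=17))
print(total_tokens_sketch(openai_like), total_tokens_sketch(gemini_like))  # 42 17
print(total_tokens_sketch({}))  # 0 -- no usage reported at all
```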
@@ -26,7 +26,7 @@ from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
-from rag.utils import num_tokens_from_string
+from rag.utils import num_tokens_from_string, total_token_count_from_response


 class Base(ABC):
@@ -125,7 +125,7 @@ class Base(ABC):
             b64 = base64.b64encode(data).decode("utf-8")
             return f"data:{mime};base64,{b64}"
         with BytesIO() as buffered:
-            fmt = "JPEG"
+            fmt = "jpeg"
             try:
                 image.save(buffered, format="JPEG")
             except Exception:
@@ -133,10 +133,10 @@ class Base(ABC):
                 buffered.seek(0)
                 buffered.truncate()
                 image.save(buffered, format="PNG")
-                fmt = "PNG"
+                fmt = "png"
             data = buffered.getvalue()
         b64 = base64.b64encode(data).decode("utf-8")
-        mime = f"image/{fmt.lower()}"
+        mime = f"image/{fmt}"
         return f"data:{mime};base64,{b64}"

     def prompt(self, b64):
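As a side note on the hunk above, storing `fmt` lowercased up front lets the MIME type be interpolated directly (`image/jpeg`, `image/png`) instead of calling `.lower()` at the end. A self-contained sketch of the same encode-with-fallback path, assuming only Pillow:

```python
import base64
from io import BytesIO
from PIL import Image

def image_to_data_url(image: Image.Image) -> str:
    # Try JPEG first; fall back to PNG when JPEG can't encode the mode
    # (e.g. RGBA), mirroring the logic in the hunk above.
    with BytesIO() as buffered:
        fmt = "jpeg"
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            buffered.seek(0)
            buffered.truncate()
            image.save(buffered, format="PNG")
            fmt = "png"
        data = buffered.getvalue()
    b64 = base64.b64encode(data).decode("utf-8")
    return f"data:image/{fmt};base64,{b64}"

print(image_to_data_url(Image.new("RGBA", (4, 4)))[:22])  # data:image/png;base64,
```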
@@ -178,7 +178,7 @@ class GptV4(Base):
             model=self.model_name,
             messages=self.prompt(b64),
         )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
+        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
@@ -186,7 +186,7 @@ class GptV4(Base):
             model=self.model_name,
             messages=self.vision_llm_prompt(b64, prompt),
         )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
+        return res.choices[0].message.content.strip(), total_token_count_from_response(res)


 class AzureGptV4(GptV4):
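A likely motivation for replacing `res.usage.total_tokens` in the two hunks above (inferred, not stated in the commit message): OpenAI-compatible gateways do not all populate `usage`, and attribute access on a `None` usage turns a successful completion into an exception. A stand-in illustration:

```python
from types import SimpleNamespace

# Hypothetical response from an OpenAI-compatible backend that omits usage.
res = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(content=" a caption "))],
    usage=None,
)
text = res.choices[0].message.content.strip()
# res.usage.total_tokens would raise AttributeError ('NoneType' has no
# attribute 'total_tokens'); a tolerant lookup degrades to 0 instead.
tokens = res.usage.total_tokens if res.usage is not None else 0
print(repr(text), tokens)  # 'a caption' 0
```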
@@ -522,11 +522,10 @@ class GeminiCV(Base):
         )
         b64 = self.image2base64(image)
         with BytesIO(base64.b64decode(b64)) as bio:
-            img = open(bio)
-            input = [prompt, img]
-            res = self.model.generate_content(input)
-            img.close()
-        return res.text, res.usage_metadata.total_token_count
+            with open(bio) as img:
+                input = [prompt, img]
+                res = self.model.generate_content(input)
+        return res.text, total_token_count_from_response(res)

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open
@@ -534,11 +533,10 @@ class GeminiCV(Base):
         b64 = self.image2base64(image)
         vision_prompt = prompt if prompt else vision_llm_describe_prompt()
         with BytesIO(base64.b64decode(b64)) as bio:
-            img = open(bio)
-            input = [vision_prompt, img]
-            res = self.model.generate_content(input)
-            img.close()
-        return res.text, res.usage_metadata.total_token_count
+            with open(bio) as img:
+                input = [vision_prompt, img]
+                res = self.model.generate_content(input)
+        return res.text, total_token_count_from_response(res)

     def chat(self, system, history, gen_conf, images=[]):
         generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
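The `with open(bio) as img:` rewrite in the two hunks above relies on PIL images being context managers: the handle is released even when `generate_content` raises, whereas the old `img.close()` was skipped on any exception. A minimal demonstration:

```python
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
buf.seek(0)

try:
    with Image.open(buf) as img:
        raise RuntimeError("simulated generate_content failure")
except RuntimeError:
    pass

# The with-block detached the image's file pointer despite the exception.
print(img.fp is None)  # True
```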
@@ -547,7 +545,7 @@ class GeminiCV(Base):
                 self._form_history(system, history, images),
                 generation_config=generation_config)
             ans = response.text
-            return ans, response.usage_metadata.total_token_count
+            return ans, total_token_count_from_response(response)
         except Exception as e:
             return "**ERROR**: " + str(e), 0

@@ -570,10 +568,7 @@ class GeminiCV(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

-        if response and hasattr(response, "usage_metadata") and hasattr(response.usage_metadata, "total_token_count"):
-            yield response.usage_metadata.total_token_count
-        else:
-            yield 0
+        yield total_token_count_from_response(response)


 class NvidiaCV(Base):
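In the streaming hunk above, the generator's contract (as used in ragflow's chat paths) is to yield text chunks and then a final integer token count, so the change collapses the guarded attribute checks into a single helper call. A consumption sketch under that assumption, with a stand-in generator:

```python
def chat_streamly_stub():
    # Stand-in for the real generator: text deltas first, token total last.
    yield "Hello"
    yield ", world"
    yield 12  # would now come from total_token_count_from_response(response)

answer, tokens = "", 0
for item in chat_streamly_stub():
    if isinstance(item, int):
        tokens = item
    else:
        answer += item
print(answer, tokens)  # Hello, world 12
```

The final hunk is the helper itself, in the shared utils module: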
@@ -95,6 +95,12 @@ def total_token_count_from_response(resp):
     except Exception:
         pass

+    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_token_count"):
+        try:
+            return resp.usage_metadata.total_token_count
+        except Exception:
+            pass
+
     if 'usage' in resp and 'total_tokens' in resp['usage']:
         try:
             return resp["usage"]["total_tokens"]
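With the new branch in place, the helper covers Gemini-style `usage_metadata` objects alongside the dict-style payloads checked just below it. A quick check with stand-in objects, assuming the helper's earlier (unshown) branches are guarded the same way:

```python
from types import SimpleNamespace

from rag.utils import total_token_count_from_response

# Gemini-style object: matches the newly added usage_metadata branch.
gemini_like = SimpleNamespace(
    usage_metadata=SimpleNamespace(total_token_count=21),
)
# Plain-dict payload: matches the pre-existing 'usage' branch.
dict_like = {"usage": {"total_tokens": 34}}

print(total_token_count_from_response(gemini_like))  # 21
print(total_token_count_from_response(dict_like))    # 34
```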