Refactor: Improve the logic to calculate embedding total token count (#11943)
### What problem does this PR solve?

Improve the logic to calculate embedding total token count.

### Type of change

- [x] Refactoring
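The diff swaps the per-class `self.total_token_count(...)` call for a shared `total_token_count_from_response(...)` imported from `common.token_utils`. The helper's body is not part of this diff; the following is only a minimal sketch, assuming it reads `usage.total_tokens` from either an SDK response object or a plain dict and falls back to 0 when no usage is reported.

```python
# Hypothetical sketch of the helper imported from common.token_utils;
# the real implementation is not shown in this diff and may differ.
def total_token_count_from_response(resp) -> int:
    """Best-effort read of usage.total_tokens from an embedding response.

    Accepts SDK objects (resp.usage.total_tokens) as well as plain dicts
    (resp["usage"]["total_tokens"]); returns 0 when no usage is reported.
    """
    try:
        usage = resp.get("usage") if isinstance(resp, dict) else getattr(resp, "usage", None)
        if usage is None:
            return 0
        total = usage.get("total_tokens") if isinstance(usage, dict) else getattr(usage, "total_tokens", 0)
        return int(total or 0)
    except (TypeError, ValueError):
        return 0
```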
@@ -28,7 +28,7 @@ from openai import OpenAI
 from zhipuai import ZhipuAI

 from common.log_utils import log_exception
-from common.token_utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
 from common import settings
 import logging
 import base64
@@ -118,7 +118,7 @@ class OpenAIEmbed(Base):
             res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
             try:
                 ress.extend([d.embedding for d in res.data])
-                total_tokens += self.total_token_count(res)
+                total_tokens += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
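For reference, the batched accumulation pattern this hunk touches looks roughly like the sketch below; it is not the exact method from the repository. `client` is assumed to be an OpenAI-compatible embeddings client, and `total_token_count_from_response` is the hypothetical helper sketched above.

```python
import numpy as np

# Sketch of the batched encode pattern: embed texts in slices of batch_size
# and sum token usage with the shared helper instead of a per-class method.
def encode_batched(client, model_name, texts, batch_size=16):
    ress, total_tokens = [], 0
    for i in range(0, len(texts), batch_size):
        res = client.embeddings.create(input=texts[i : i + batch_size], model=model_name, encoding_format="float")
        ress.extend([d.embedding for d in res.data])
        total_tokens += total_token_count_from_response(res)
    return np.array(ress), total_tokens
```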
@@ -127,7 +127,7 @@ class OpenAIEmbed(Base):
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
         try:
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -216,7 +216,7 @@ class QWenEmbed(Base):
                 for e in resp["output"]["embeddings"]:
                     embds[e["text_index"]] = e["embedding"]
                 res.extend(embds)
-                token_count += self.total_token_count(resp)
+                token_count += self.total_token_count_from_response(resp)
             except Exception as _e:
                 log_exception(_e, resp)
                 raise
@@ -225,7 +225,7 @@ class QWenEmbed(Base):
     def encode_queries(self, text):
         resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
         try:
-            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
+            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count_from_response(resp)
         except Exception as _e:
             log_exception(_e, resp)
             raise Exception(f"Error: {resp}")
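Note that the dashscope responses in the QWen hunks are indexed as dicts (`resp["output"]...`), while the OpenAI-style responses above are attribute objects, so whatever the new counting logic is, it presumably has to accept both shapes. A quick check against the hypothetical sketch given earlier (not a test from the repository):

```python
from types import SimpleNamespace

object_style = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))  # OpenAI-style SDK object
dict_style = {"usage": {"total_tokens": 42}}                            # dashscope / response.json() dict
no_usage = {"output": {"embeddings": []}}                               # response without usage info

assert total_token_count_from_response(object_style) == 42
assert total_token_count_from_response(dict_style) == 42
assert total_token_count_from_response(no_usage) == 0
```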
@@ -253,7 +253,7 @@ class ZhipuEmbed(Base):
             res = self.client.embeddings.create(input=txt, model=self.model_name)
             try:
                 arr.append(res.data[0].embedding)
-                tks_num += self.total_token_count(res)
+                tks_num += self.total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
@@ -262,7 +262,7 @@ class ZhipuEmbed(Base):
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=text, model=self.model_name)
         try:
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -323,7 +323,7 @@ class XinferenceEmbed(Base):
             try:
                 res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
                 ress.extend([d.embedding for d in res.data])
-                total_tokens += self.total_token_count(res)
+                total_tokens += self.total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
@@ -333,7 +333,7 @@ class XinferenceEmbed(Base):
         res = None
         try:
             res = self.client.embeddings.create(input=[text], model=self.model_name)
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -409,7 +409,7 @@ class JinaMultiVecEmbed(Base):

                     ress.append(chunk_emb)

-                token_count += self.total_token_count(res)
+                token_count += self.total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, response)
                 raise Exception(f"Error: {response}")
@@ -443,7 +443,7 @@ class MistralEmbed(Base):
                 try:
                     res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
                     ress.extend([d.embedding for d in res.data])
-                    token_count += self.total_token_count(res)
+                    token_count += self.total_token_count_from_response(res)
                     break
                 except Exception as _e:
                     if retry_max == 1:
@@ -460,7 +460,7 @@ class MistralEmbed(Base):
         while retry_max > 0:
             try:
                 res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
-                return np.array(res.data[0].embedding), self.total_token_count(res)
+                return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
             except Exception as _e:
                 if retry_max == 1:
                     log_exception(_e)
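The MistralEmbed hunks sit inside a retry loop (`while retry_max > 0` ... `if retry_max == 1`). Roughly, the pattern is the one sketched below; the starting value of `retry_max` and the sleep between attempts are assumptions, not values taken from this diff.

```python
import time

# Generic sketch of the retry-with-countdown pattern visible in the
# MistralEmbed hunks: retry until the budget is spent, then re-raise.
def call_with_retry(fn, retry_max=3, delay=1.0):
    while retry_max > 0:
        try:
            return fn()
        except Exception:
            if retry_max == 1:
                # Last attempt failed: surface the error instead of retrying.
                raise
            retry_max -= 1
            time.sleep(delay)
```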
@@ -595,7 +595,7 @@ class NvidiaEmbed(Base):
             try:
                 res = response.json()
                 ress.extend([d["embedding"] for d in res["data"]])
-                token_count += self.total_token_count(res)
+                token_count += self.total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, response)
                 raise Exception(f"Error: {response}")
@@ -732,7 +732,7 @@ class SILICONFLOWEmbed(Base):
             try:
                 res = response.json()
                 ress.extend([d["embedding"] for d in res["data"]])
-                token_count += self.total_token_count(res)
+                token_count += self.total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, response)
                 raise Exception(f"Error: {response}")
@@ -748,7 +748,7 @@ class SILICONFLOWEmbed(Base):
         response = requests.post(self.base_url, json=payload, headers=self.headers)
         try:
             res = response.json()
-            return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
+            return np.array(res["data"][0]["embedding"]), self.total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, response)
             raise Exception(f"Error: {response}")
@@ -794,7 +794,7 @@ class BaiduYiyanEmbed(Base):
         try:
             return (
                 np.array([r["embedding"] for r in res["data"]]),
-                self.total_token_count(res),
+                self.total_token_count_from_response(res),
             )
         except Exception as _e:
             log_exception(_e, res)
@@ -805,7 +805,7 @@ class BaiduYiyanEmbed(Base):
         try:
             return (
                 np.array([r["embedding"] for r in res["data"]]),
-                self.total_token_count(res),
+                self.total_token_count_from_response(res),
             )
         except Exception as _e:
             log_exception(_e, res)