Refactor for total_tokens. (#4652)

### What problem does this PR solve?

#4567
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-01-26 13:54:26 +08:00
committed by GitHub
parent c24137bd11
commit 4776fa5e4e
3 changed files with 79 additions and 52 deletions

View File

@ -44,11 +44,23 @@ class Base(ABC):
# Base-class stub for embedding a single query text: it unconditionally raises
# NotImplementedError, so each concrete embedding class is expected to provide
# its own encode_queries.
# NOTE(review): the message says "encode method" although this is
# encode_queries -- possibly copied from a sibling stub; confirm before rewording.
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
    """Return the total token usage reported by an embedding response.

    Two response shapes are probed, in order:

    1. attribute style: ``resp.usage.total_tokens``
    2. mapping style:   ``resp["usage"]["total_tokens"]``

    Args:
        resp: A provider response, either an object exposing ``usage`` as an
            attribute or a mapping with a ``"usage"`` key. Any other value
            (including ``None``) is tolerated.

    Returns:
        The reported token count, or ``0`` when neither shape yields a
        usable (non-``None``) value.
    """
    lookups = (
        lambda r: r.usage.total_tokens,
        lambda r: r["usage"]["total_tokens"],
    )
    for lookup in lookups:
        try:
            count = lookup(resp)
        except Exception:
            # Deliberately broad: this is a best-effort probe and different
            # response wrappers may raise different errors for a missing
            # field. Fall through to the next shape.
            continue
        # A usage field that is present but None must not be handed back:
        # callers accumulate the result with ``+=`` and would raise TypeError.
        if count is not None:
            return count
    return 0
class DefaultEmbedding(Base):
_model = None
_model_name = ""
_model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -115,13 +127,13 @@ class OpenAIEmbed(Base):
res = self.client.embeddings.create(input=texts[i:i + batch_size],
model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += res.usage.total_tokens
total_tokens += self.total_token_count(res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
@ -188,7 +200,7 @@ class QWenEmbed(Base):
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += resp["usage"]["total_tokens"]
token_count += self.total_token_count(resp)
return np.array(res), token_count
except Exception as e:
raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
@ -203,7 +215,7 @@ class QWenEmbed(Base):
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"]
["embedding"]), self.total_token_count(resp)
except Exception:
raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
return np.array([]), 0
@ -229,13 +241,13 @@ class ZhipuEmbed(Base):
res = self.client.embeddings.create(input=txt,
model=self.model_name)
arr.append(res.data[0].embedding)
tks_num += res.usage.total_tokens
tks_num += self.total_token_count(res)
return np.array(arr), tks_num
def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)
class OllamaEmbed(Base):
@ -318,13 +330,13 @@ class XinferenceEmbed(Base):
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += res.usage.total_tokens
total_tokens += self.total_token_count(res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)
class YoudaoEmbed(Base):
@ -383,7 +395,7 @@ class JinaEmbed(Base):
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
@ -447,13 +459,13 @@ class MistralEmbed(Base):
res = self.client.embeddings(input=texts[i:i + batch_size],
model=self.model_name)
ress.extend([d.embedding for d in res.data])
token_count += res.usage.total_tokens
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.embeddings(input=[truncate(text, 8196)],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)
class BedrockEmbed(Base):
@ -565,7 +577,7 @@ class NvidiaEmbed(Base):
}
res = requests.post(self.base_url, headers=self.headers, json=payload).json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
@ -677,7 +689,7 @@ class SILICONFLOWEmbed(Base):
if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
@ -689,7 +701,7 @@ class SILICONFLOWEmbed(Base):
res = requests.post(self.base_url, json=payload, headers=self.headers).json()
if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
class ReplicateEmbed(Base):
@ -727,14 +739,14 @@ class BaiduYiyanEmbed(Base):
res = self.client.do(model=self.model_name, texts=texts).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
self.total_token_count(res),
)
def encode_queries(self, text):
res = self.client.do(model=self.model_name, texts=[text]).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
self.total_token_count(res),
)