From 2a0f835ffea54e4bc899dfa03a4d211b278bea2f Mon Sep 17 00:00:00 2001
From: Stephen Hu <812791840@qq.com>
Date: Mon, 15 Dec 2025 11:33:57 +0800
Subject: [PATCH] Refactor: Improve the logic to calculate embedding total token count (#11943)

### What problem does this PR solve?

Improve the logic to calculate embedding total token count

### Type of change

- [x] Refactoring
---
 rag/llm/embedding_model.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 1890b68d0..58cdd8e76 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -28,7 +28,7 @@ from openai import OpenAI
 from zhipuai import ZhipuAI
 from common.log_utils import log_exception
-from common.token_utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
 from common import settings
 import logging
 import base64
@@ -118,7 +118,7 @@ class OpenAIEmbed(Base):
             res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
             try:
                 ress.extend([d.embedding for d in res.data])
-                total_tokens += self.total_token_count(res)
+                total_tokens += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
@@ -127,7 +127,7 @@ def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
         try:
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -216,7 +216,7 @@ class QWenEmbed(Base):
                 for e in resp["output"]["embeddings"]:
                     embds[e["text_index"]] = e["embedding"]
                 res.extend(embds)
-                token_count += self.total_token_count(resp)
+                token_count += total_token_count_from_response(resp)
             except Exception as _e:
                 log_exception(_e, resp)
                 raise
@@ -225,7 +225,7 @@ def encode_queries(self, text):
         resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
         try:
-            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
+            return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
         except Exception as _e:
             log_exception(_e, resp)
             raise Exception(f"Error: {resp}")
@@ -253,7 +253,7 @@ class ZhipuEmbed(Base):
             res = self.client.embeddings.create(input=txt, model=self.model_name)
             try:
                 arr.append(res.data[0].embedding)
-                tks_num += self.total_token_count(res)
+                tks_num += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
@@ -262,7 +262,7 @@ def encode_queries(self, text):
         res = self.client.embeddings.create(input=text, model=self.model_name)
         try:
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -323,7 +323,7 @@ class XinferenceEmbed(Base):
             try:
                 res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
                 ress.extend([d.embedding for d in res.data])
-                total_tokens += self.total_token_count(res)
+                total_tokens += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, res)
                 raise Exception(f"Error: {res}")
@@ -333,7 +333,7 @@ class XinferenceEmbed(Base):
         res = None
         try:
             res = self.client.embeddings.create(input=[text], model=self.model_name)
-            return np.array(res.data[0].embedding), self.total_token_count(res)
+            return np.array(res.data[0].embedding), total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, res)
             raise Exception(f"Error: {res}")
@@ -409,7 +409,7 @@ class JinaMultiVecEmbed(Base):
                 ress.append(chunk_emb)
-            token_count += self.total_token_count(res)
+            token_count += total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, response)
             raise Exception(f"Error: {response}")
@@ -443,7 +443,7 @@ class MistralEmbed(Base):
             try:
                 res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
                 ress.extend([d.embedding for d in res.data])
-                token_count += self.total_token_count(res)
+                token_count += total_token_count_from_response(res)
                 break
             except Exception as _e:
                 if retry_max == 1:
@@ -460,7 +460,7 @@ class MistralEmbed(Base):
         while retry_max > 0:
             try:
                 res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
-                return np.array(res.data[0].embedding), self.total_token_count(res)
+                return np.array(res.data[0].embedding), total_token_count_from_response(res)
             except Exception as _e:
                 if retry_max == 1:
                     log_exception(_e)
@@ -595,7 +595,7 @@ class NvidiaEmbed(Base):
             try:
                 res = response.json()
                 ress.extend([d["embedding"] for d in res["data"]])
-                token_count += self.total_token_count(res)
+                token_count += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, response)
                 raise Exception(f"Error: {response}")
@@ -732,7 +732,7 @@ class SILICONFLOWEmbed(Base):
             try:
                 res = response.json()
                 ress.extend([d["embedding"] for d in res["data"]])
-                token_count += self.total_token_count(res)
+                token_count += total_token_count_from_response(res)
             except Exception as _e:
                 log_exception(_e, response)
                 raise Exception(f"Error: {response}")
@@ -748,7 +748,7 @@ class SILICONFLOWEmbed(Base):
         response = requests.post(self.base_url, json=payload, headers=self.headers)
         try:
             res = response.json()
-            return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
+            return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
         except Exception as _e:
             log_exception(_e, response)
             raise Exception(f"Error: {response}")
@@ -794,7 +794,7 @@ class BaiduYiyanEmbed(Base):
         try:
             return (
                 np.array([r["embedding"] for r in res["data"]]),
-                self.total_token_count(res),
+                total_token_count_from_response(res),
             )
         except Exception as _e:
             log_exception(_e, res)
@@ -805,7 +805,7 @@ class BaiduYiyanEmbed(Base):
         try:
             return (
                 np.array([r["embedding"] for r in res["data"]]),
-                self.total_token_count(res),
+                total_token_count_from_response(res),
             )
         except Exception as _e:
             log_exception(_e, res)
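For context: the helper imported above lives in `common/token_utils.py` and its implementation is not part of this diff. The sketch below is only an assumption of what a response-agnostic `total_token_count_from_response` might look like, to illustrate why every call site can drop the per-class `self.total_token_count` method and share one function across SDK objects and raw JSON responses.

```python
# Hypothetical sketch only -- the real helper lives in common/token_utils.py
# and is not shown in this patch.
def total_token_count_from_response(resp) -> int:
    """Best-effort extraction of usage.total_tokens from an embedding response.

    Handles attribute-style SDK objects (OpenAI-compatible clients, Mistral)
    as well as plain dicts parsed from REST responses (dashscope, Nvidia,
    SiliconFlow, BaiduYiyan). Returns 0 when no usage information is present.
    """
    usage = getattr(resp, "usage", None)
    if usage is None and isinstance(resp, dict):
        usage = resp.get("usage")
    if usage is None:
        return 0
    total = getattr(usage, "total_tokens", None)
    if total is None and isinstance(usage, dict):
        total = usage.get("total_tokens")
    return int(total or 0)
```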