From 94dbd4aac90d9888d784315f2477308d1625127c Mon Sep 17 00:00:00 2001
From: Stephen Hu
Date: Mon, 22 Sep 2025 17:17:06 +0800
Subject: [PATCH] Refactor: use the same implementation for total token count
 from res (#10197)

### What problem does this PR solve?

use the same implementation for total token count from res

### Type of change

- [x] Refactoring
---
 rag/llm/chat_model.py      | 12 ++----------
 rag/llm/embedding_model.py | 12 ++----------
 rag/llm/rerank_model.py    | 15 ++-------------
 rag/utils/__init__.py      | 14 ++++++++++++++
 4 files changed, 20 insertions(+), 33 deletions(-)

diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index b43277fc0..a3fb357f3 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
 
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
-from rag.utils import num_tokens_from_string
+from rag.utils import num_tokens_from_string, total_token_count_from_response
 
 
 # Error message constants
@@ -445,15 +445,7 @@ class Base(ABC):
         yield total_tokens
 
     def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
+        return total_token_count_from_response(resp)
 
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 2e78b3dbc..4fc17eb31 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
 
 
 class Base(ABC):
@@ -52,15 +52,7 @@
         raise NotImplementedError("Please implement encode method!")
 
     def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
+        return total_token_count_from_response(resp)
 
 
 class DefaultEmbedding(Base):
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index f4ca7fb01..7256b047b 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -30,7 +30,7 @@ from yarl import URL
 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
 
 class Base(ABC):
     def __init__(self, key, model_name, **kwargs):
@@ -44,18 +44,7 @@
         raise NotImplementedError("Please implement encode method!")
 
     def total_token_count(self, resp):
-        if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
-            try:
-                return resp.usage.total_tokens
-            except Exception:
-                pass
-
-        if 'usage' in resp and 'total_tokens' in resp['usage']:
-            try:
-                return resp["usage"]["total_tokens"]
-            except Exception:
-                pass
-        return 0
+        return total_token_count_from_response(resp)
 
 
 class DefaultRerank(Base):
diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py
index 8468bf4c3..22445da92 100644
--- a/rag/utils/__init__.py
+++ b/rag/utils/__init__.py
@@ -88,6 +88,20 @@ def num_tokens_from_string(string: str) -> int:
     except Exception:
         return 0
 
+def total_token_count_from_response(resp):
+    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+
+    if 'usage' in resp and 'total_tokens' in resp['usage']:
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+    return 0
+
 
 def truncate(string: str, max_len: int) -> str:
     """Returns truncated text if the length of text exceed max_len."""