Refactor: use the same implementation for total token count from response (#10197)

### What problem does this PR solve?
Use the same implementation for extracting the total token count from a response.

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2025-09-22 17:17:06 +08:00
committed by GitHub
parent ca9f30e1a1
commit 94dbd4aac9
4 changed files with 20 additions and 33 deletions

View File

@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string, total_token_count_from_response
# Error message constants # Error message constants
@ -445,15 +445,7 @@ class Base(ABC):
yield total_tokens yield total_tokens
def total_token_count(self, resp): def total_token_count(self, resp):
try: return total_token_count_from_response(resp)
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
def _calculate_dynamic_ctx(self, history): def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size""" """Calculate dynamic context window size"""

View File

@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
from api import settings from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC): class Base(ABC):
@ -52,15 +52,7 @@ class Base(ABC):
raise NotImplementedError("Please implement encode method!") raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp): def total_token_count(self, resp):
try: return total_token_count_from_response(resp)
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class DefaultEmbedding(Base): class DefaultEmbedding(Base):

View File

@ -30,7 +30,7 @@ from yarl import URL
from api import settings from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name, **kwargs): def __init__(self, key, model_name, **kwargs):
@ -44,18 +44,7 @@ class Base(ABC):
raise NotImplementedError("Please implement encode method!") raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp): def total_token_count(self, resp):
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"): return total_token_count_from_response(resp)
try:
return resp.usage.total_tokens
except Exception:
pass
if 'usage' in resp and 'total_tokens' in resp['usage']:
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class DefaultRerank(Base): class DefaultRerank(Base):

View File

@ -88,6 +88,20 @@ def num_tokens_from_string(string: str) -> int:
except Exception: except Exception:
return 0 return 0
def total_token_count_from_response(resp):
    """Return the total token count reported by a model response, or 0.

    Two response shapes are probed, in this order:

    1. Attribute style: ``resp.usage.total_tokens``.
    2. Mapping style: ``resp["usage"]["total_tokens"]``.

    This is a best-effort lookup: any response that does not match either
    shape (including ``None``, objects without usage data, or a ``usage``
    entry that is ``None``) yields 0 instead of raising.

    Args:
        resp: The response to inspect; may be any object.

    Returns:
        The value stored under ``total_tokens`` when it can be found,
        otherwise 0.
    """
    if resp is None:
        return 0

    # Attribute-style access first.
    usage = getattr(resp, "usage", None)
    if usage is not None and hasattr(usage, "total_tokens"):
        try:
            return usage.total_tokens
        except Exception:
            # Deliberate best-effort: fall through to the mapping probe.
            pass

    # Mapping-style access. The membership tests live inside the ``try``
    # because ``in`` raises TypeError on objects that are not containers
    # (plain SDK objects, or a ``usage`` value of None).
    try:
        if "usage" in resp and "total_tokens" in resp["usage"]:
            return resp["usage"]["total_tokens"]
    except (TypeError, KeyError, IndexError):
        pass

    return 0
def truncate(string: str, max_len: int) -> str: def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len.""" """Returns truncated text if the length of text exceed max_len."""