Refactor: use the same implementation for total token count from response (#10197)

### What problem does this PR solve?
Use the same implementation to extract the total token count from a response.

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2025-09-22 17:17:06 +08:00
committed by GitHub
parent ca9f30e1a1
commit 94dbd4aac9
4 changed files with 20 additions and 33 deletions

View File

@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
from rag.utils import num_tokens_from_string, total_token_count_from_response
# Error message constants
@ -445,15 +445,7 @@ class Base(ABC):
yield total_tokens
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
return total_token_count_from_response(resp)
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""

View File

@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
@ -52,15 +52,7 @@ class Base(ABC):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
return total_token_count_from_response(resp)
class DefaultEmbedding(Base):

View File

@ -30,7 +30,7 @@ from yarl import URL
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
@ -44,18 +44,7 @@ class Base(ABC):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
try:
return resp.usage.total_tokens
except Exception:
pass
if 'usage' in resp and 'total_tokens' in resp['usage']:
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
return total_token_count_from_response(resp)
class DefaultRerank(Base):

View File

@ -88,6 +88,20 @@ def num_tokens_from_string(string: str) -> int:
except Exception:
return 0
def total_token_count_from_response(resp):
    """Return the total token count reported in a response's usage data.

    Two response shapes are supported:

    * attribute style: ``resp.usage.total_tokens``
    * mapping style: ``resp["usage"]["total_tokens"]``

    Args:
        resp: The response to inspect. May be any object, including ``None``.

    Returns:
        The ``total_tokens`` value found in the response, or ``0`` when
        neither shape yields one.
    """
    # Attribute style first. A missing ``usage`` or ``total_tokens`` (or a
    # ``usage`` that is None) surfaces as AttributeError and falls through.
    try:
        return resp.usage.total_tokens
    except AttributeError:
        pass

    # Mapping style. Subscripting is attempted directly instead of testing
    # ``"usage" in resp`` up front: a membership test raises TypeError on
    # objects that are not containers (e.g. None or a plain object), which
    # would turn this best-effort lookup into a crash.
    try:
        return resp["usage"]["total_tokens"]
    except (LookupError, TypeError):
        pass

    return 0
def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""