mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor: use the same implementation for total token count from response (#10197)
### What problem does this PR solve? Use the same implementation for extracting the total token count from a response. ### Type of change - [x] Refactoring
This commit is contained in:
@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
|
|||||||
|
|
||||||
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
||||||
from rag.nlp import is_chinese, is_english
|
from rag.nlp import is_chinese, is_english
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string, total_token_count_from_response
|
||||||
|
|
||||||
|
|
||||||
# Error message constants
|
# Error message constants
|
||||||
@ -445,15 +445,7 @@ class Base(ABC):
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
def total_token_count(self, resp):
|
def total_token_count(self, resp):
|
||||||
try:
|
return total_token_count_from_response(resp)
|
||||||
return resp.usage.total_tokens
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
return resp["usage"]["total_tokens"]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
def _calculate_dynamic_ctx(self, history):
|
||||||
"""Calculate dynamic context window size"""
|
"""Calculate dynamic context window size"""
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
|
|||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from api.utils.log_utils import log_exception
|
from api.utils.log_utils import log_exception
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
@ -52,15 +52,7 @@ class Base(ABC):
|
|||||||
raise NotImplementedError("Please implement encode method!")
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
def total_token_count(self, resp):
|
def total_token_count(self, resp):
|
||||||
try:
|
return total_token_count_from_response(resp)
|
||||||
return resp.usage.total_tokens
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
return resp["usage"]["total_tokens"]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultEmbedding(Base):
|
class DefaultEmbedding(Base):
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from yarl import URL
|
|||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from api.utils.log_utils import log_exception
|
from api.utils.log_utils import log_exception
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
@ -44,18 +44,7 @@ class Base(ABC):
|
|||||||
raise NotImplementedError("Please implement encode method!")
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
def total_token_count(self, resp):
|
def total_token_count(self, resp):
|
||||||
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
|
return total_token_count_from_response(resp)
|
||||||
try:
|
|
||||||
return resp.usage.total_tokens
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if 'usage' in resp and 'total_tokens' in resp['usage']:
|
|
||||||
try:
|
|
||||||
return resp["usage"]["total_tokens"]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultRerank(Base):
|
class DefaultRerank(Base):
|
||||||
|
|||||||
@ -88,6 +88,20 @@ def num_tokens_from_string(string: str) -> int:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def total_token_count_from_response(resp):
    """Return the total token count reported by a model response.

    Two response shapes are supported:

    * attribute style: ``resp.usage.total_tokens``
    * mapping style: ``resp["usage"]["total_tokens"]``

    The lookup is best-effort: if neither shape matches (including when
    ``resp`` is ``None`` or is an object that supports neither access
    style), ``0`` is returned instead of raising.

    Args:
        resp: The response object or mapping to inspect.

    Returns:
        The value stored under ``total_tokens`` when found, otherwise ``0``.
    """
    if resp is None:
        return 0

    # Attribute-style access, e.g. SDK response objects.
    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
        try:
            return resp.usage.total_tokens
        except Exception:
            # Deliberate best-effort: fall through to the mapping lookup.
            pass

    # Mapping-style access, e.g. a decoded JSON payload. The membership
    # tests are kept inside the ``try`` because ``in`` raises ``TypeError``
    # on objects that are not containers (plain SDK objects, or a ``None``
    # value stored under "usage").
    try:
        if "usage" in resp and "total_tokens" in resp["usage"]:
            return resp["usage"]["total_tokens"]
    except Exception:
        # Deliberate best-effort: an unusable response counts as 0 tokens.
        pass

    return 0
|
||||||
|
|
||||||
|
|
||||||
def truncate(string: str, max_len: int) -> str:
|
def truncate(string: str, max_len: int) -> str:
|
||||||
"""Returns truncated text if the length of text exceed max_len."""
|
"""Returns truncated text if the length of text exceed max_len."""
|
||||||
|
|||||||
Reference in New Issue
Block a user