mirror of https://github.com/infiniflow/ragflow.git
Move token related functions to common (#10942)
### What problem does this PR solve?

Moves the token-related helpers (`num_tokens_from_string`, `total_token_count_from_response`, `truncate`, and the shared tiktoken `encoder`) out of `rag.utils` into a new `common/token_utils.py` module, and updates every import site accordingly.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
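For call sites the change is a pure import swap; a minimal sketch of the before/after (the sample text and token budget are invented for illustration):

```python
# Old import (removed by this PR):
#   from rag.utils import num_tokens_from_string, truncate
# New import:
from common.token_utils import num_tokens_from_string, truncate

text = "How many tokens does this sentence use?"
budget = 8  # illustrative token budget, not a value from this PR
if num_tokens_from_string(text) > budget:
    # truncate() keeps the first `budget` tokens and decodes them back to text
    text = truncate(text, budget)
```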
@@ -41,7 +41,7 @@ from rag.app.tag import label_question
 from rag.nlp.search import index_name
 from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
     gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces

@@ -16,7 +16,7 @@
 import inspect
 import logging
 import re
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from functools import partial
 from typing import Generator
 from api.db.db_models import LLM
common/token_utils.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+
+import os
+import tiktoken
+
+from common.file_utils import get_project_base_directory
+
+tiktoken_cache_dir = get_project_base_directory()
+os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
+# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
+encoder = tiktoken.get_encoding("cl100k_base")
+
+
+def num_tokens_from_string(string: str) -> int:
+    """Returns the number of tokens in a text string."""
+    try:
+        code_list = encoder.encode(string)
+        return len(code_list)
+    except Exception:
+        return 0
+
+
+def total_token_count_from_response(resp):
+    if resp is None:
+        return 0
+
+    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+
+    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
+        try:
+            return resp.usage_metadata.total_tokens
+        except Exception:
+            pass
+
+    if 'usage' in resp and 'total_tokens' in resp['usage']:
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+
+    if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
+        try:
+            return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
+        except Exception:
+            pass
+
+    if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
+        try:
+            return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
+        except Exception:
+            pass
+    return 0
+
+
+def truncate(string: str, max_len: int) -> str:
+    """Returns truncated text if the length of text exceed max_len."""
+    return encoder.decode(encoder.encode(string)[:max_len])
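The fallback chain in `total_token_count_from_response` is easiest to see with concrete inputs; a small illustration (the response shapes below are invented to mirror the function's branches, not taken from any provider SDK):

```python
from types import SimpleNamespace

from common.token_utils import total_token_count_from_response

# Object style: resp.usage.total_tokens (first branch)
openai_like = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
assert total_token_count_from_response(openai_like) == 42

# Dict style: separate input/output counts (fourth branch)
rest_like = {"usage": {"input_tokens": 30, "output_tokens": 12}}
assert total_token_count_from_response(rest_like) == 42

# Unknown shape: every branch misses, so the count degrades to 0
assert total_token_count_from_response({"choices": []}) == 0
```

The ordering matters: object attributes are checked before dict keys, and an aggregate `total_tokens` is preferred over summing input/output pairs.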
@@ -17,7 +17,7 @@
 import re

 from deepdoc.parser.utils import get_text
-from rag.nlp import num_tokens_from_string
+from common.token_utils import num_tokens_from_string


 class RAGFlowTxtParser:
@@ -21,7 +21,7 @@ from graphrag.general.extractor import Extractor
 from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 import trio


@@ -38,7 +38,7 @@ from graphrag.utils import (
 )
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.prompts.generator import message_fit_in
-from rag.utils import truncate
+from common.token_utils import truncate

 GRAPH_FIELD_SEP = "<SEP>"
 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
@@ -16,7 +16,7 @@ from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string

 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
@@ -165,7 +165,7 @@ async def run_graphrag_for_kb(
         return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}

     def load_doc_chunks(doc_id: str) -> list[str]:
-        from rag.utils import num_tokens_from_string
+        from common.token_utils import num_tokens_from_string

         chunks = []
         current_chunk = ""
@@ -27,7 +27,7 @@ from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 import markdown_to_json
 from functools import reduce
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string


 @dataclass
@@ -17,7 +17,7 @@ from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
 from graphrag.light.graph_prompt import PROMPTS
 from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split_string_by_multi_markers
 from rag.llm.chat_model import Base as CompletionLLM
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string


 @dataclass
@@ -24,7 +24,7 @@ import trio
 from common.misc_utils import get_uuid
 from graphrag.query_analyze_prompt import PROMPTS
 from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from rag.utils.doc_store_conn import OrderByExpr

 from rag.nlp.search import Dealer, index_name
@@ -21,7 +21,7 @@ import re
 from api.db import ParserType
 from io import BytesIO
 from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from deepdoc.parser import PdfParser, PlainParser, DocxParser
 from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
 from docx import Document
@@ -29,7 +29,7 @@ from rag.flow.tokenizer.schema import TokenizerFromUpstream
 from rag.nlp import rag_tokenizer
 from rag.settings import EMBEDDING_BATCH_SIZE
 from rag.svr.task_executor import embed_limiter
-from rag.utils import truncate
+from common.token_utils import truncate


 class TokenizerParam(ProcessParamBase):
@@ -36,7 +36,7 @@ from zhipuai import ZhipuAI

 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response


 # Error message constants
@@ -30,7 +30,7 @@ from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response


 class Base(ABC):
@@ -28,7 +28,7 @@ from openai import OpenAI
 from zhipuai import ZhipuAI

 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
 from api import settings
 import logging

@@ -23,7 +23,7 @@ import requests
 from yarl import URL

 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response

 class Base(ABC):
     def __init__(self, key, model_name, **kwargs):
@@ -24,7 +24,7 @@ import requests
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI

-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string


 class Base(ABC):
@@ -36,7 +36,7 @@ import requests
 import websocket
 from pydantic import BaseModel, conint

-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string


 class ServeReferenceAudio(BaseModel):
@@ -18,7 +18,7 @@ import logging
 import random
 from collections import Counter

-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from . import rag_tokenizer
 import re
 import copy
@@ -26,7 +26,7 @@ from common.misc_utils import hash_str2int
 from rag.nlp import rag_tokenizer
 from rag.prompts.template import load_prompt
 from rag.settings import TAG_FLD
-from rag.utils import encoder, num_tokens_from_string
+from common.token_utils import encoder, num_tokens_from_string


 STOP_TOKEN="<|STOP|>"
@@ -28,7 +28,7 @@ from graphrag.utils import (
     set_llm_cache,
     chat_limiter,
 )
-from rag.utils import truncate
+from common.token_utils import truncate


 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@@ -65,7 +65,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
 from rag.nlp import search, rag_tokenizer, add_positions
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
 from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
@@ -14,59 +14,3 @@
 # limitations under the License.
 #
-
-import os
-import tiktoken
-
-from common.file_utils import get_project_base_directory
-
-tiktoken_cache_dir = get_project_base_directory()
-os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
-# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
-encoder = tiktoken.get_encoding("cl100k_base")
-
-
-def num_tokens_from_string(string: str) -> int:
-    """Returns the number of tokens in a text string."""
-    try:
-        return len(encoder.encode(string))
-    except Exception:
-        return 0
-
-
-def total_token_count_from_response(resp):
-    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-
-    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
-        try:
-            return resp.usage_metadata.total_tokens
-        except Exception:
-            pass
-
-    if 'usage' in resp and 'total_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-
-    if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
-        except Exception:
-            pass
-
-    if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
-        try:
-            return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
-        except Exception:
-            pass
-    return 0
-
-
-def truncate(string: str, max_len: int) -> str:
-    """Returns truncated text if the length of text exceed max_len."""
-    return encoder.decode(encoder.encode(string)[:max_len])
test/unit_test/common/test_token_utils.py (new file, 431 lines)
@@ -0,0 +1,431 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+from common.token_utils import num_tokens_from_string, total_token_count_from_response, truncate, encoder
+import pytest
+
+
+class TestNumTokensFromString:
+    """Test cases for num_tokens_from_string function"""
+
+    def test_empty_string(self):
+        """Test that empty string returns zero tokens"""
+        result = num_tokens_from_string("")
+        assert result == 0
+
+    def test_single_word(self):
+        """Test token count for a single word"""
+        # "hello" should be 1 token with cl100k_base encoding
+        result = num_tokens_from_string("hello")
+        assert result == 1
+
+    def test_multiple_words(self):
+        """Test token count for multiple words"""
+        # "hello world" typically becomes 2 tokens
+        result = num_tokens_from_string("hello world")
+        assert result == 2
+
+    def test_special_characters(self):
+        """Test token count with special characters"""
+        result = num_tokens_from_string("hello, world!")
+        # Special characters may be separate tokens
+        assert result == 4
+
+    def test_hanzi_characters(self):
+        """Test token count with Chinese (hanzi) characters"""
+        result = num_tokens_from_string("世界")
+        # CJK characters may map to one or more tokens each
+        assert result > 0
+
+    def test_unicode_characters(self):
+        """Test token count with unicode characters"""
+        result = num_tokens_from_string("Hello 世界 🌍")
+        # Unicode characters typically require multiple tokens
+        assert result > 0
+
+    def test_long_text(self):
+        """Test token count for longer text"""
+        long_text = "This is a longer piece of text that should contain multiple sentences. " \
+                    "It will help verify that the token counting works correctly for substantial input."
+        result = num_tokens_from_string(long_text)
+        assert result > 10
+
+    def test_whitespace_only(self):
+        """Test token count for whitespace-only strings"""
+        result = num_tokens_from_string("   \n\t   ")
+        # Whitespace may or may not be tokens depending on the encoding
+        assert result >= 0
+
+    def test_numbers(self):
+        """Test token count with numerical values"""
+        result = num_tokens_from_string("12345 678.90")
+        assert result > 0
+
+    def test_mixed_content(self):
+        """Test token count with mixed content types"""
+        mixed_text = "Hello! 123 Main St. Price: $19.99 🎉"
+        result = num_tokens_from_string(mixed_text)
+        assert result > 0
+
+    def test_encoding_error_handling(self):
+        """Test that function handles encoding errors gracefully"""
+        # This test verifies the exception handling in the function.
+        # The function should return 0 when encoding fails.
+        # Note: We can't easily simulate encoding errors without mocking
+        pass
+
+
+# Additional parameterized tests for efficiency
+@pytest.mark.parametrize("input_string,expected_min_tokens", [
+    ("a", 1),                    # Single character
+    ("test", 1),                 # Single word
+    ("hello world", 2),          # Two words
+    ("This is a sentence.", 4),  # Short sentence
+    # ("A" * 100, 100),          # Repeated characters
+])
+def test_token_count_ranges(input_string, expected_min_tokens):
+    """Parameterized test for various input strings"""
+    result = num_tokens_from_string(input_string)
+    assert result >= expected_min_tokens
+
+
+def test_consistency():
+    """Test that the same input produces consistent results"""
+    test_string = "Consistent token counting"
+    first_result = num_tokens_from_string(test_string)
+    second_result = num_tokens_from_string(test_string)
+
+    assert first_result == second_result
+    assert first_result > 0
+
+
+from unittest.mock import Mock
+
+
+class TestTotalTokenCountFromResponse:
+    """Test cases for total_token_count_from_response function"""
+
+    def test_dict_with_usage_total_tokens(self):
+        """Test dictionary response with usage['total_tokens']"""
+        resp_dict = {
+            'usage': {
+                'total_tokens': 175
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 175
+
+    def test_dict_with_usage_input_output_tokens(self):
+        """Test dictionary response with input_tokens and output_tokens in usage"""
+        resp_dict = {
+            'usage': {
+                'input_tokens': 100,
+                'output_tokens': 50
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 150
+
+    def test_dict_with_meta_tokens_input_output(self):
+        """Test dictionary response with meta.tokens.input_tokens and output_tokens"""
+        resp_dict = {
+            'meta': {
+                'tokens': {
+                    'input_tokens': 80,
+                    'output_tokens': 40
+                }
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 120
+
+    def test_priority_order_usage_total_tokens_first(self):
+        """Test that resp.usage.total_tokens takes priority over other formats"""
+        # Create a response that matches multiple conditions
+        mock_usage = Mock()
+        mock_usage.total_tokens = 300
+
+        mock_usage_metadata = Mock()
+        mock_usage_metadata.total_tokens = 400
+
+        mock_resp = Mock()
+        mock_resp.usage = mock_usage
+        mock_resp.usage_metadata = mock_usage_metadata
+
+        result = total_token_count_from_response(mock_resp)
+        assert result == 300  # Should use the first matching condition
+
+    def test_priority_order_usage_metadata_second(self):
+        """Test that resp.usage_metadata.total_tokens is second in priority"""
+        # Create a response without resp.usage but with resp.usage_metadata
+        mock_usage_metadata = Mock()
+        mock_usage_metadata.total_tokens = 250
+
+        mock_resp = Mock()
+        delattr(mock_resp, 'usage')  # Ensure no usage attribute
+        mock_resp.usage_metadata = mock_usage_metadata
+
+        result = total_token_count_from_response(mock_resp)
+        assert result == 250
+
+    def test_priority_order_dict_usage_total_tokens_third(self):
+        """Test that dict['usage']['total_tokens'] is third in priority"""
+        resp_dict = {
+            'usage': {
+                'total_tokens': 180,
+                'input_tokens': 100,
+                'output_tokens': 80
+            },
+            'meta': {
+                'tokens': {
+                    'input_tokens': 200,
+                    'output_tokens': 100
+                }
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 180  # Should use total_tokens from usage
+
+    def test_priority_order_dict_usage_input_output_fourth(self):
+        """Test that dict['usage']['input_tokens'] + output_tokens is fourth in priority"""
+        resp_dict = {
+            'usage': {
+                'input_tokens': 120,
+                'output_tokens': 60
+            },
+            'meta': {
+                'tokens': {
+                    'input_tokens': 200,
+                    'output_tokens': 100
+                }
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 180  # Should sum input_tokens + output_tokens from usage
+
+    def test_priority_order_meta_tokens_last(self):
+        """Test that meta.tokens is the last option in priority"""
+        resp_dict = {
+            'meta': {
+                'tokens': {
+                    'input_tokens': 90,
+                    'output_tokens': 30
+                }
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 120
+
+    def test_no_token_info_returns_zero(self):
+        """Test that function returns 0 when no token information is found"""
+        empty_resp = {}
+        result = total_token_count_from_response(empty_resp)
+        assert result == 0
+
+    def test_partial_dict_usage_missing_output_tokens(self):
+        """Test dictionary with usage but missing output_tokens"""
+        resp_dict = {
+            'usage': {
+                'input_tokens': 100
+                # Missing output_tokens
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 0  # Should not match any condition and return 0
+
+    def test_partial_meta_tokens_missing_input_tokens(self):
+        """Test dictionary with meta.tokens but missing input_tokens"""
+        resp_dict = {
+            'meta': {
+                'tokens': {
+                    'output_tokens': 50
+                    # Missing input_tokens
+                }
+            }
+        }
+
+        result = total_token_count_from_response(resp_dict)
+        assert result == 0  # Should not match any condition and return 0
+
+    def test_none_response(self):
+        """Test that function handles None response gracefully"""
+        result = total_token_count_from_response(None)
+        assert result == 0
+
+    def test_invalid_response_type(self):
+        """Test that function handles invalid response types gracefully"""
+        result = total_token_count_from_response("invalid response")
+        assert result == 0
+
+        # result = total_token_count_from_response(123)
+        # assert result == 0
+
+
+# Parameterized tests for different response formats
+@pytest.mark.parametrize("response_data,expected_tokens", [
+    # Object with usage.total_tokens
+    ({"usage": Mock(total_tokens=150)}, 150),
+    # Dict with usage.total_tokens
+    ({"usage": {"total_tokens": 175}}, 175),
+    # Dict with usage.input_tokens + output_tokens
+    ({"usage": {"input_tokens": 100, "output_tokens": 50}}, 150),
+    # Dict with meta.tokens.input_tokens + output_tokens
+    ({"meta": {"tokens": {"input_tokens": 80, "output_tokens": 40}}}, 120),
+    # Empty dict
+    ({}, 0),
+])
+def test_various_response_formats(response_data, expected_tokens):
+    """Test various response formats using parameterized tests"""
+    if isinstance(response_data, dict) and not any(isinstance(v, Mock) for v in response_data.values()):
+        # Regular dictionary
+        resp = response_data
+    else:
+        # Mock object
+        resp = Mock()
+        for key, value in response_data.items():
+            setattr(resp, key, value)
+
+    result = total_token_count_from_response(resp)
+    assert result == expected_tokens
+
+
+class TestTruncate:
+    """Test cases for truncate function"""
+
+    def test_empty_string(self):
+        """Test truncation of empty string"""
+        result = truncate("", 5)
+        assert result == ""
+        assert isinstance(result, str)
+
+    def test_string_shorter_than_max_len(self):
+        """Test string that is shorter than max_len"""
+        original_string = "hello"
+        result = truncate(original_string, 10)
+        assert result == original_string
+        assert len(encoder.encode(result)) <= 10
+
+    def test_string_equal_to_max_len(self):
+        """Test string that exactly equals max_len in tokens"""
+        # Use the string's own encoded length as the exact max_len
+        test_string = "hello world test"
+        encoded = encoder.encode(test_string)
+        exact_length = len(encoded)
+
+        result = truncate(test_string, exact_length)
+        assert result == test_string
+        assert len(encoder.encode(result)) == exact_length
+
+    def test_string_longer_than_max_len(self):
+        """Test string that is longer than max_len"""
+        long_string = "This is a longer string that will be truncated"
+        max_len = 5
+
+        result = truncate(long_string, max_len)
+        assert len(encoder.encode(result)) == max_len
+        assert result != long_string
+
+    def test_truncation_preserves_beginning(self):
+        """Test that truncation preserves the beginning of the string"""
+        test_string = "The quick brown fox jumps over the lazy dog"
+        max_len = 3
+
+        result = truncate(test_string, max_len)
+        encoded_result = encoder.encode(result)
+
+        # The truncated result should match the beginning of the original encoding
+        original_encoded = encoder.encode(test_string)
+        assert encoded_result == original_encoded[:max_len]
+
+    def test_unicode_characters(self):
+        """Test truncation with unicode characters"""
+        unicode_string = "Hello 世界 🌍 测试"
+        max_len = 4
+
+        result = truncate(unicode_string, max_len)
+        assert len(encoder.encode(result)) == max_len
+        # Should be a valid string
+        assert isinstance(result, str)
+
+    def test_special_characters(self):
+        """Test truncation with special characters"""
+        special_string = "Hello, world! @#$%^&*()"
+        max_len = 3
+
+        result = truncate(special_string, max_len)
+        assert len(encoder.encode(result)) == max_len
+
+    def test_whitespace_string(self):
+        """Test truncation of whitespace-only string"""
+        whitespace_string = "   \n\t   "
+        max_len = 2
+
+        result = truncate(whitespace_string, max_len)
+        assert len(encoder.encode(result)) <= max_len
+        assert isinstance(result, str)
+
+    def test_max_len_zero(self):
+        """Test truncation with max_len = 0"""
+        test_string = "hello world"
+        result = truncate(test_string, 0)
+        assert result == ""
+        assert len(encoder.encode(result)) == 0
+
+    def test_max_len_one(self):
+        """Test truncation with max_len = 1"""
+        test_string = "hello world"
+        result = truncate(test_string, 1)
+        assert len(encoder.encode(result)) == 1
+
+    def test_preserves_decoding_encoding_consistency(self):
+        """Test that truncation preserves encoding-decoding consistency"""
+        test_string = "This is a test string for encoding consistency"
+        max_len = 6
+
+        result = truncate(test_string, max_len)
+        # Re-encoding the result should give the same token count
+        re_encoded = encoder.encode(result)
+        assert len(re_encoded) == max_len
+
+    def test_multibyte_characters_truncation(self):
+        """Test truncation with multibyte characters that span multiple tokens"""
+        # Some unicode characters may require multiple tokens
+        multibyte_string = "🚀🌟🎉✨🔥💫"
+        max_len = 3
+
+        result = truncate(multibyte_string, max_len)
+        assert len(encoder.encode(result)) == max_len
+
+    def test_mixed_english_chinese_text(self):
+        """Test truncation with mixed English and Chinese text"""
+        mixed_string = "Hello 世界, this is a test 测试"
+        max_len = 5
+
+        result = truncate(mixed_string, max_len)
+        assert len(encoder.encode(result)) == max_len
+
+    def test_numbers_and_symbols(self):
+        """Test truncation with numbers and symbols"""
+        number_string = "12345 678.90 $100.00 @username #tag"
+        max_len = 4
+
+        result = truncate(number_string, max_len)
+        assert len(encoder.encode(result)) == max_len
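The suite can be run on its own; a sketch of a programmatic invocation (assumes pytest is installed and the command is run from the repository root):

```python
# Equivalent to: pytest test/unit_test/common/test_token_utils.py -v
import sys

import pytest

sys.exit(pytest.main(["test/unit_test/common/test_token_utils.py", "-v"]))
```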