From 360f5c1179369cac2f74234a3bcb5381f263967a Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Mon, 3 Nov 2025 08:50:05 +0800 Subject: [PATCH] Move token related functions to common (#10942) ### What problem does this PR solve? As title ### Type of change - [x] Refactoring Signed-off-by: Jin Hai --- api/db/services/dialog_service.py | 2 +- api/db/services/llm_service.py | 2 +- common/token_utils.py | 76 +++ deepdoc/parser/txt_parser.py | 2 +- .../general/community_reports_extractor.py | 2 +- graphrag/general/extractor.py | 2 +- graphrag/general/graph_extractor.py | 2 +- graphrag/general/index.py | 2 +- graphrag/general/mind_map_extractor.py | 2 +- graphrag/light/graph_extractor.py | 2 +- graphrag/search.py | 2 +- rag/app/manual.py | 2 +- rag/flow/tokenizer/tokenizer.py | 2 +- rag/llm/chat_model.py | 2 +- rag/llm/cv_model.py | 2 +- rag/llm/embedding_model.py | 2 +- rag/llm/rerank_model.py | 2 +- rag/llm/sequence2txt_model.py | 2 +- rag/llm/tts_model.py | 2 +- rag/nlp/__init__.py | 2 +- rag/prompts/generator.py | 2 +- rag/raptor.py | 2 +- rag/svr/task_executor.py | 2 +- rag/utils/__init__.py | 56 --- test/unit_test/common/test_token_utils.py | 431 ++++++++++++++++++ 25 files changed, 529 insertions(+), 78 deletions(-) create mode 100644 common/token_utils.py create mode 100644 test/unit_test/common/test_token_utils.py diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index f3e6c49fd..93467e523 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -41,7 +41,7 @@ from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 0befc6f6e..50d20ed8e 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -16,7 +16,7 @@ import inspect import logging import re -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator from api.db.db_models import LLM diff --git a/common/token_utils.py b/common/token_utils.py new file mode 100644 index 000000000..29f10f7eb --- /dev/null +++ b/common/token_utils.py @@ -0,0 +1,76 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import os +import tiktoken + +from common.file_utils import get_project_base_directory + +tiktoken_cache_dir = get_project_base_directory() +os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir +# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") +encoder = tiktoken.get_encoding("cl100k_base") + + +def num_tokens_from_string(string: str) -> int: + """Returns the number of tokens in a text string.""" + try: + code_list = encoder.encode(string) + return len(code_list) + except Exception: + return 0 + +def total_token_count_from_response(resp): + if resp is None: + return 0 + + if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"): + try: + return resp.usage.total_tokens + except Exception: + pass + + if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"): + try: + return resp.usage_metadata.total_tokens + except Exception: + pass + + if 'usage' in resp and 'total_tokens' in resp['usage']: + try: + return resp["usage"]["total_tokens"] + except Exception: + pass + + if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']: + try: + return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"] + except Exception: + pass + + if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']: + try: + return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"] + except Exception: + pass + return 0 + + +def truncate(string: str, max_len: int) -> str: + """Returns truncated text if the length of text exceed max_len.""" + return encoder.decode(encoder.encode(string)[:max_len]) + diff --git a/deepdoc/parser/txt_parser.py b/deepdoc/parser/txt_parser.py index 1b22865e1..64e200cbc 100644 --- a/deepdoc/parser/txt_parser.py +++ b/deepdoc/parser/txt_parser.py @@ -17,7 +17,7 @@ import re from deepdoc.parser.utils import get_text -from rag.nlp import num_tokens_from_string +from common.token_utils import num_tokens_from_string class RAGFlowTxtParser: diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 6f9fd65b9..fe611dedf 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -21,7 +21,7 @@ from graphrag.general.extractor import Extractor from graphrag.general.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string import trio diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 28b591d55..a41ffd6a2 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -38,7 +38,7 @@ from graphrag.utils import ( ) from rag.llm.chat_model import Base as CompletionLLM from rag.prompts.generator import message_fit_in -from rag.utils import truncate +from common.token_utils import truncate GRAPH_FIELD_SEP = "" DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"] diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index 346a5b95b..59ebeeddf 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -16,7 +16,7 @@ from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROM from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers from rag.llm.chat_model import Base as CompletionLLM import networkx as nx -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string DEFAULT_TUPLE_DELIMITER = "<|>" DEFAULT_RECORD_DELIMITER = "##" diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 650b511de..52b298e32 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -165,7 +165,7 @@ async def run_graphrag_for_kb( return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0} def load_doc_chunks(doc_id: str) -> list[str]: - from rag.utils import num_tokens_from_string + from common.token_utils import num_tokens_from_string chunks = [] current_chunk = "" diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py index d713eb59f..c85579d3d 100644 --- a/graphrag/general/mind_map_extractor.py +++ b/graphrag/general/mind_map_extractor.py @@ -27,7 +27,7 @@ from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_l from rag.llm.chat_model import Base as CompletionLLM import markdown_to_json from functools import reduce -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string @dataclass diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index 474d47597..c2827c00f 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -17,7 +17,7 @@ from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extracto from graphrag.light.graph_prompt import PROMPTS from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split_string_by_multi_markers from rag.llm.chat_model import Base as CompletionLLM -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string @dataclass diff --git a/graphrag/search.py b/graphrag/search.py index a415c7610..cdc7e785f 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -24,7 +24,7 @@ import trio from common.misc_utils import get_uuid from graphrag.query_analyze_prompt import PROMPTS from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string from rag.utils.doc_store_conn import OrderByExpr from rag.nlp.search import Dealer, index_name diff --git a/rag/app/manual.py b/rag/app/manual.py index a433a10e2..e093ac490 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -21,7 +21,7 @@ import re from api.db import ParserType from io import BytesIO from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string from deepdoc.parser import PdfParser, PlainParser, DocxParser from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper from docx import Document diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index 23e179496..6b7d6ad2a 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -29,7 +29,7 @@ from rag.flow.tokenizer.schema import TokenizerFromUpstream from rag.nlp import rag_tokenizer from rag.settings import EMBEDDING_BATCH_SIZE from rag.svr.task_executor import embed_limiter -from rag.utils import truncate +from common.token_utils import truncate class TokenizerParam(ProcessParamBase): diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index cd22e76ec..e938ef844 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -36,7 +36,7 @@ from zhipuai import ZhipuAI from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.nlp import is_chinese, is_english -from rag.utils import num_tokens_from_string, total_token_count_from_response +from common.token_utils import num_tokens_from_string, total_token_count_from_response # Error message constants diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 2ef0cb54a..c0f90807d 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -30,7 +30,7 @@ from openai.lib.azure import AzureOpenAI from zhipuai import ZhipuAI from rag.nlp import is_english from rag.prompts.generator import vision_llm_describe_prompt -from rag.utils import num_tokens_from_string, total_token_count_from_response +from common.token_utils import num_tokens_from_string, total_token_count_from_response class Base(ABC): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 10eba69d3..82564a056 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -28,7 +28,7 @@ from openai import OpenAI from zhipuai import ZhipuAI from api.utils.log_utils import log_exception -from rag.utils import num_tokens_from_string, truncate +from common.token_utils import num_tokens_from_string, truncate from api import settings import logging diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 7a4207d1e..15fae3d34 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -23,7 +23,7 @@ import requests from yarl import URL from api.utils.log_utils import log_exception -from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response +from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response class Base(ABC): def __init__(self, key, model_name, **kwargs): diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index c66adada4..dcbc7c2f2 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -24,7 +24,7 @@ import requests from openai import OpenAI from openai.lib.azure import AzureOpenAI -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string class Base(ABC): diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index b073016ff..3cebff329 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -36,7 +36,7 @@ import requests import websocket from pydantic import BaseModel, conint -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string class ServeReferenceAudio(BaseModel): diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 4fd6eb5bc..61a3b6f3a 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -18,7 +18,7 @@ import logging import random from collections import Counter -from rag.utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string from . import rag_tokenizer import re import copy diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index f5680540b..78f6cbad9 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -26,7 +26,7 @@ from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer from rag.prompts.template import load_prompt from rag.settings import TAG_FLD -from rag.utils import encoder, num_tokens_from_string +from common.token_utils import encoder, num_tokens_from_string STOP_TOKEN="<|STOP|>" diff --git a/rag/raptor.py b/rag/raptor.py index 191ecdeb4..a3d369189 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -28,7 +28,7 @@ from graphrag.utils import ( set_llm_cache, chat_limiter, ) -from rag.utils import truncate +from common.token_utils import truncate class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 8a0d464ba..dafaaf28f 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -65,7 +65,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, from rag.nlp import search, rag_tokenizer, add_positions from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD -from rag.utils import num_tokens_from_string, truncate +from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.utils.storage_factory import STORAGE_IMPL from graphrag.utils import chat_limiter diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index 64db543b1..86f2171be 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -14,59 +14,3 @@ # limitations under the License. # -import os -import tiktoken - -from common.file_utils import get_project_base_directory - -tiktoken_cache_dir = get_project_base_directory() -os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir -# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") -encoder = tiktoken.get_encoding("cl100k_base") - - -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - try: - return len(encoder.encode(string)) - except Exception: - return 0 - -def total_token_count_from_response(resp): - if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"): - try: - return resp.usage.total_tokens - except Exception: - pass - - if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"): - try: - return resp.usage_metadata.total_tokens - except Exception: - pass - - if 'usage' in resp and 'total_tokens' in resp['usage']: - try: - return resp["usage"]["total_tokens"] - except Exception: - pass - - if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']: - try: - return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"] - except Exception: - pass - - if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']: - try: - return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"] - except Exception: - pass - return 0 - - -def truncate(string: str, max_len: int) -> str: - """Returns truncated text if the length of text exceed max_len.""" - return encoder.decode(encoder.encode(string)[:max_len]) - - diff --git a/test/unit_test/common/test_token_utils.py b/test/unit_test/common/test_token_utils.py new file mode 100644 index 000000000..c16c4cb49 --- /dev/null +++ b/test/unit_test/common/test_token_utils.py @@ -0,0 +1,431 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from common.token_utils import num_tokens_from_string, total_token_count_from_response, truncate, encoder +import pytest + +class TestNumTokensFromString: + """Test cases for num_tokens_from_string function""" + + def test_empty_string(self): + """Test that empty string returns zero tokens""" + result = num_tokens_from_string("") + assert result == 0 + + def test_single_word(self): + """Test token count for a single word""" + # "hello" should be 1 token with cl100k_base encoding + result = num_tokens_from_string("hello") + assert result == 1 + + def test_multiple_words(self): + """Test token count for multiple words""" + # "hello world" typically becomes 2 tokens + result = num_tokens_from_string("hello world") + assert result == 2 + + def test_special_characters(self): + """Test token count with special characters""" + result = num_tokens_from_string("hello, world!") + # Special characters may be separate tokens + assert result == 4 + + def test_hanzi_characters(self): + """Test token count with special characters""" + result = num_tokens_from_string("δΈ–η•Œ") + # Special characters may be separate tokens + assert result > 0 + + def test_unicode_characters(self): + """Test token count with unicode characters""" + result = num_tokens_from_string("Hello δΈ–η•Œ 🌍") + # Unicode characters typically require multiple tokens + assert result > 0 + + def test_long_text(self): + """Test token count for longer text""" + long_text = "This is a longer piece of text that should contain multiple sentences. " \ + "It will help verify that the token counting works correctly for substantial input." + result = num_tokens_from_string(long_text) + assert result > 10 + + def test_whitespace_only(self): + """Test token count for whitespace-only strings""" + result = num_tokens_from_string(" \n\t ") + # Whitespace may or may not be tokens depending on the encoding + assert result >= 0 + + def test_numbers(self): + """Test token count with numerical values""" + result = num_tokens_from_string("12345 678.90") + assert result > 0 + + def test_mixed_content(self): + """Test token count with mixed content types""" + mixed_text = "Hello! 123 Main St. Price: $19.99 πŸŽ‰" + result = num_tokens_from_string(mixed_text) + assert result > 0 + + def test_encoding_error_handling(self): + """Test that function handles encoding errors gracefully""" + # This test verifies the exception handling in the function. + # The function should return 0 when encoding fails + # Note: We can't easily simulate encoding errors without mocking + pass + + +# Additional parameterized tests for efficiency +@pytest.mark.parametrize("input_string,expected_min_tokens", [ + ("a", 1), # Single character + ("test", 1), # Single word + ("hello world", 2), # Two words + ("This is a sentence.", 4), # Short sentence + # ("A" * 100, 100), # Repeated characters +]) +def test_token_count_ranges(input_string, expected_min_tokens): + """Parameterized test for various input strings""" + result = num_tokens_from_string(input_string) + assert result >= expected_min_tokens + + +def test_consistency(): + """Test that the same input produces consistent results""" + test_string = "Consistent token counting" + first_result = num_tokens_from_string(test_string) + second_result = num_tokens_from_string(test_string) + + assert first_result == second_result + assert first_result > 0 + + +from unittest.mock import Mock + +class TestTotalTokenCountFromResponse: + """Test cases for total_token_count_from_response function""" + + def test_dict_with_usage_total_tokens(self): + """Test dictionary response with usage['total_tokens']""" + resp_dict = { + 'usage': { + 'total_tokens': 175 + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 175 + + def test_dict_with_usage_input_output_tokens(self): + """Test dictionary response with input_tokens and output_tokens in usage""" + resp_dict = { + 'usage': { + 'input_tokens': 100, + 'output_tokens': 50 + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 150 + + def test_dict_with_meta_tokens_input_output(self): + """Test dictionary response with meta.tokens.input_tokens and output_tokens""" + resp_dict = { + 'meta': { + 'tokens': { + 'input_tokens': 80, + 'output_tokens': 40 + } + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 120 + + def test_priority_order_usage_total_tokens_first(self): + """Test that resp.usage.total_tokens takes priority over other formats""" + # Create a response that matches multiple conditions + mock_usage = Mock() + mock_usage.total_tokens = 300 + + mock_usage_metadata = Mock() + mock_usage_metadata.total_tokens = 400 + + mock_resp = Mock() + mock_resp.usage = mock_usage + mock_resp.usage_metadata = mock_usage_metadata + + result = total_token_count_from_response(mock_resp) + assert result == 300 # Should use the first matching condition + + def test_priority_order_usage_metadata_second(self): + """Test that resp.usage_metadata.total_tokens is second in priority""" + # Create a response without resp.usage but with resp.usage_metadata + mock_usage_metadata = Mock() + mock_usage_metadata.total_tokens = 250 + + mock_resp = Mock() + delattr(mock_resp, 'usage') # Ensure no usage attribute + mock_resp.usage_metadata = mock_usage_metadata + + result = total_token_count_from_response(mock_resp) + assert result == 250 + + def test_priority_order_dict_usage_total_tokens_third(self): + """Test that dict['usage']['total_tokens'] is third in priority""" + resp_dict = { + 'usage': { + 'total_tokens': 180, + 'input_tokens': 100, + 'output_tokens': 80 + }, + 'meta': { + 'tokens': { + 'input_tokens': 200, + 'output_tokens': 100 + } + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 180 # Should use total_tokens from usage + + def test_priority_order_dict_usage_input_output_fourth(self): + """Test that dict['usage']['input_tokens'] + output_tokens is fourth in priority""" + resp_dict = { + 'usage': { + 'input_tokens': 120, + 'output_tokens': 60 + }, + 'meta': { + 'tokens': { + 'input_tokens': 200, + 'output_tokens': 100 + } + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 180 # Should sum input_tokens + output_tokens from usage + + def test_priority_order_meta_tokens_last(self): + """Test that meta.tokens is the last option in priority""" + resp_dict = { + 'meta': { + 'tokens': { + 'input_tokens': 90, + 'output_tokens': 30 + } + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 120 + + def test_no_token_info_returns_zero(self): + """Test that function returns 0 when no token information is found""" + empty_resp = {} + result = total_token_count_from_response(empty_resp) + assert result == 0 + + def test_partial_dict_usage_missing_output_tokens(self): + """Test dictionary with usage but missing output_tokens""" + resp_dict = { + 'usage': { + 'input_tokens': 100 + # Missing output_tokens + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 0 # Should not match the condition and return 0 + + def test_partial_meta_tokens_missing_input_tokens(self): + """Test dictionary with meta.tokens but missing input_tokens""" + resp_dict = { + 'meta': { + 'tokens': { + 'output_tokens': 50 + # Missing input_tokens + } + } + } + + result = total_token_count_from_response(resp_dict) + assert result == 0 # Should not match the condition and return 0 + + def test_none_response(self): + """Test that function handles None response gracefully""" + result = total_token_count_from_response(None) + assert result == 0 + + def test_invalid_response_type(self): + """Test that function handles invalid response types gracefully""" + result = total_token_count_from_response("invalid response") + assert result == 0 + + # result = total_token_count_from_response(123) + # assert result == 0 + + +# Parameterized tests for different response formats +@pytest.mark.parametrize("response_data,expected_tokens", [ + # Object with usage.total_tokens + ({"usage": Mock(total_tokens=150)}, 150), + # Dict with usage.total_tokens + ({"usage": {"total_tokens": 175}}, 175), + # Dict with usage.input_tokens + output_tokens + ({"usage": {"input_tokens": 100, "output_tokens": 50}}, 150), + # Dict with meta.tokens.input_tokens + output_tokens + ({"meta": {"tokens": {"input_tokens": 80, "output_tokens": 40}}}, 120), + # Empty dict + ({}, 0), +]) +def test_various_response_formats(response_data, expected_tokens): + """Test various response formats using parameterized tests""" + if isinstance(response_data, dict) and not any(isinstance(v, Mock) for v in response_data.values()): + # Regular dictionary + resp = response_data + else: + # Mock object + resp = Mock() + for key, value in response_data.items(): + setattr(resp, key, value) + + result = total_token_count_from_response(resp) + assert result == expected_tokens + + +class TestTruncate: + """Test cases for truncate function""" + + def test_empty_string(self): + """Test truncation of empty string""" + result = truncate("", 5) + assert result == "" + assert isinstance(result, str) + + def test_string_shorter_than_max_len(self): + """Test string that is shorter than max_len""" + original_string = "hello" + result = truncate(original_string, 10) + assert result == original_string + assert len(encoder.encode(result)) <= 10 + + def test_string_equal_to_max_len(self): + """Test string that exactly equals max_len in tokens""" + # Create a string that encodes to exactly 5 tokens + test_string = "hello world test" + encoded = encoder.encode(test_string) + exact_length = len(encoded) + + result = truncate(test_string, exact_length) + assert result == test_string + assert len(encoder.encode(result)) == exact_length + + def test_string_longer_than_max_len(self): + """Test string that is longer than max_len""" + long_string = "This is a longer string that will be truncated" + max_len = 5 + + result = truncate(long_string, max_len) + assert len(encoder.encode(result)) == max_len + assert result != long_string + + def test_truncation_preserves_beginning(self): + """Test that truncation preserves the beginning of the string""" + test_string = "The quick brown fox jumps over the lazy dog" + max_len = 3 + + result = truncate(test_string, max_len) + encoded_result = encoder.encode(result) + + # The truncated result should match the beginning of the original encoding + original_encoded = encoder.encode(test_string) + assert encoded_result == original_encoded[:max_len] + + def test_unicode_characters(self): + """Test truncation with unicode characters""" + unicode_string = "Hello δΈ–η•Œ 🌍 ζ΅‹θ―•" + max_len = 4 + + result = truncate(unicode_string, max_len) + assert len(encoder.encode(result)) == max_len + # Should be a valid string + assert isinstance(result, str) + + def test_special_characters(self): + """Test truncation with special characters""" + special_string = "Hello, world! @#$%^&*()" + max_len = 3 + + result = truncate(special_string, max_len) + assert len(encoder.encode(result)) == max_len + + def test_whitespace_string(self): + """Test truncation of whitespace-only string""" + whitespace_string = " \n\t " + max_len = 2 + + result = truncate(whitespace_string, max_len) + assert len(encoder.encode(result)) <= max_len + assert isinstance(result, str) + + def test_max_len_zero(self): + """Test truncation with max_len = 0""" + test_string = "hello world" + result = truncate(test_string, 0) + assert result == "" + assert len(encoder.encode(result)) == 0 + + def test_max_len_one(self): + """Test truncation with max_len = 1""" + test_string = "hello world" + result = truncate(test_string, 1) + assert len(encoder.encode(result)) == 1 + + def test_preserves_decoding_encoding_consistency(self): + """Test that truncation preserves encoding-decoding consistency""" + test_string = "This is a test string for encoding consistency" + max_len = 6 + + result = truncate(test_string, max_len) + # Re-encoding the result should give the same token count + re_encoded = encoder.encode(result) + assert len(re_encoded) == max_len + + def test_multibyte_characters_truncation(self): + """Test truncation with multibyte characters that span multiple tokens""" + # Some unicode characters may require multiple tokens + multibyte_string = "πŸš€πŸŒŸπŸŽ‰βœ¨πŸ”₯πŸ’«" + max_len = 3 + + result = truncate(multibyte_string, max_len) + assert len(encoder.encode(result)) == max_len + + def test_mixed_english_chinese_text(self): + """Test truncation with mixed English and Chinese text""" + mixed_string = "Hello δΈ–η•Œ, this is a test ζ΅‹θ―•" + max_len = 5 + + result = truncate(mixed_string, max_len) + assert len(encoder.encode(result)) == max_len + + def test_numbers_and_symbols(self): + """Test truncation with numbers and symbols""" + number_string = "12345 678.90 $100.00 @username #tag" + max_len = 4 + + result = truncate(number_string, max_len) + assert len(encoder.encode(result)) == max_len \ No newline at end of file