Move token related functions to common (#10942)

### What problem does this PR solve?

Moves the token-related helpers (`num_tokens_from_string`, `total_token_count_from_response`, `truncate`) and the shared `tiktoken` encoder out of `rag.utils` into `common.token_utils`, and updates every import site accordingly.

### Type of change

- [x] Refactoring
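
At each call site the change is a one-line import swap, for example (names taken from the hunks below):

```python
# Before this commit the helpers lived in rag.utils:
#   from rag.utils import num_tokens_from_string, truncate
# After this commit they are imported from common.token_utils:
from common.token_utils import num_tokens_from_string, truncate
```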

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Authored by Jin Hai on 2025-11-03 08:50:05 +08:00; committed by GitHub
parent 44f2d6f5da
commit 360f5c1179
25 changed files with 529 additions and 78 deletions

View File

@@ -21,7 +21,7 @@ import re
from api.db import ParserType
from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, PlainParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from docx import Document

View File

@@ -29,7 +29,7 @@ from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from rag.svr.task_executor import embed_limiter
-from rag.utils import truncate
+from common.token_utils import truncate
class TokenizerParam(ProcessParamBase):

View File

@@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants

View File

@@ -30,7 +30,7 @@ from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):

View File

@@ -28,7 +28,7 @@ from openai import OpenAI
from zhipuai import ZhipuAI
from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
from api import settings
import logging

View File

@@ -23,7 +23,7 @@ import requests
from yarl import URL
from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
def __init__(self, key, model_name, **kwargs):

View File

@@ -24,7 +24,7 @@ import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
class Base(ABC):

View File

@@ -36,7 +36,7 @@ import requests
import websocket
from pydantic import BaseModel, conint
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
class ServeReferenceAudio(BaseModel):

View File

@@ -18,7 +18,7 @@ import logging
import random
from collections import Counter
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy

View File

@@ -26,7 +26,7 @@ from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
-from rag.utils import encoder, num_tokens_from_string
+from common.token_utils import encoder, num_tokens_from_string
STOP_TOKEN="<|STOP|>"

View File

@@ -28,7 +28,7 @@ from graphrag.utils import (
set_llm_cache,
chat_limiter,
)
-from rag.utils import truncate
+from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

View File

@@ -65,7 +65,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter

View File

@@ -14,59 +14,3 @@
# limitations under the License.
#
-import os
-
-import tiktoken
-
-from common.file_utils import get_project_base_directory
-
-tiktoken_cache_dir = get_project_base_directory()
-os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
-# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
-encoder = tiktoken.get_encoding("cl100k_base")
-
-
-def num_tokens_from_string(string: str) -> int:
-    """Returns the number of tokens in a text string."""
-    try:
-        return len(encoder.encode(string))
-    except Exception:
-        return 0
-
-
-def total_token_count_from_response(resp):
-    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
-        try:
-            return resp.usage_metadata.total_tokens
-        except Exception:
-            pass
-    if 'usage' in resp and 'total_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-    if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
-        except Exception:
-            pass
-    if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
-        try:
-            return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
-        except Exception:
-            pass
-    return 0
-
-
-def truncate(string: str, max_len: int) -> str:
-    """Returns truncated text if the length of text exceed max_len."""
-    return encoder.decode(encoder.encode(string)[:max_len])
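
For reference, a minimal usage sketch of the relocated helpers, assuming `common.token_utils` keeps the exact API removed here (the import hunks above confirm it exposes `encoder`, `num_tokens_from_string`, `truncate`, and `total_token_count_from_response`):

```python
# Minimal sketch, not part of this diff: exercises the relocated helpers,
# assuming common.token_utils preserves the API removed from rag.utils above.
from common.token_utils import (
    num_tokens_from_string,
    total_token_count_from_response,
    truncate,
)

text = "Token counts come from the shared cl100k_base tiktoken encoder."

# Counts tokens; the removed implementation returns 0 if encoding fails.
print(num_tokens_from_string(text))

# Truncates by token count, not character count.
print(truncate(text, max_len=5))

# Handles dict-shaped responses as well as objects with a .usage attribute.
resp = {"usage": {"total_tokens": 42}}
assert total_token_count_from_response(resp) == 42
```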