Mirror of https://github.com/infiniflow/ragflow.git, synced 2026-01-31 23:55:06 +08:00
Move token related functions to common (#10942)
### What problem does this PR solve?

As the title says: move the token-related helper functions (`num_tokens_from_string`, `total_token_count_from_response`, `truncate`, and the shared `encoder`) out of `rag.utils` and into `common.token_utils`, updating every import site.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
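Every hunk below applies the same mechanical substitution at a call site; a minimal before/after sketch of the pattern (both import lines are taken verbatim from the hunks):

```python
# Before: helpers were re-exported from the rag.utils package
from rag.utils import num_tokens_from_string, truncate

# After: the same helpers live in the shared common package
from common.token_utils import num_tokens_from_string, truncate
```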
@@ -21,7 +21,7 @@ import re
 from api.db import ParserType
 from io import BytesIO
 from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from deepdoc.parser import PdfParser, PlainParser, DocxParser
 from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
 from docx import Document
@@ -29,7 +29,7 @@ from rag.flow.tokenizer.schema import TokenizerFromUpstream
 from rag.nlp import rag_tokenizer
 from rag.settings import EMBEDDING_BATCH_SIZE
 from rag.svr.task_executor import embed_limiter
-from rag.utils import truncate
+from common.token_utils import truncate
 
 
 class TokenizerParam(ProcessParamBase):
@@ -36,7 +36,7 @@ from zhipuai import ZhipuAI
 
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response
 
 
 # Error message constants
@@ -30,7 +30,7 @@ from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
-from rag.utils import num_tokens_from_string, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, total_token_count_from_response
 
 
 class Base(ABC):
@@ -28,7 +28,7 @@ from openai import OpenAI
 from zhipuai import ZhipuAI
 
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
 from api import settings
 import logging
 
@@ -23,7 +23,7 @@ import requests
 from yarl import URL
 
 from api.utils.log_utils import log_exception
-from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
+from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
 
 class Base(ABC):
     def __init__(self, key, model_name, **kwargs):
@@ -24,7 +24,7 @@ import requests
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -36,7 +36,7 @@ import requests
 import websocket
 from pydantic import BaseModel, conint
 
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 
 
 class ServeReferenceAudio(BaseModel):
@@ -18,7 +18,7 @@ import logging
 import random
 from collections import Counter
 
-from rag.utils import num_tokens_from_string
+from common.token_utils import num_tokens_from_string
 from . import rag_tokenizer
 import re
 import copy
@@ -26,7 +26,7 @@ from common.misc_utils import hash_str2int
 from rag.nlp import rag_tokenizer
 from rag.prompts.template import load_prompt
 from rag.settings import TAG_FLD
-from rag.utils import encoder, num_tokens_from_string
+from common.token_utils import encoder, num_tokens_from_string
 
 
 STOP_TOKEN="<|STOP|>"
@@ -28,7 +28,7 @@ from graphrag.utils import (
     set_llm_cache,
     chat_limiter,
 )
-from rag.utils import truncate
+from common.token_utils import truncate
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@@ -65,7 +65,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
 from rag.nlp import search, rag_tokenizer, add_positions
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
-from rag.utils import num_tokens_from_string, truncate
+from common.token_utils import num_tokens_from_string, truncate
 from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
@@ -14,59 +14,3 @@
 # limitations under the License.
 #
-
-import os
-import tiktoken
-
-from common.file_utils import get_project_base_directory
-
-tiktoken_cache_dir = get_project_base_directory()
-os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
-# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
-encoder = tiktoken.get_encoding("cl100k_base")
-
-
-def num_tokens_from_string(string: str) -> int:
-    """Returns the number of tokens in a text string."""
-    try:
-        return len(encoder.encode(string))
-    except Exception:
-        return 0
-
-
-def total_token_count_from_response(resp):
-    if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-
-    if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
-        try:
-            return resp.usage_metadata.total_tokens
-        except Exception:
-            pass
-
-    if 'usage' in resp and 'total_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-
-    if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
-        try:
-            return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
-        except Exception:
-            pass
-
-    if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
-        try:
-            return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
-        except Exception:
-            pass
-    return 0
-
-
-def truncate(string: str, max_len: int) -> str:
-    """Returns truncated text if the length of text exceed max_len."""
-    return encoder.decode(encoder.encode(string)[:max_len])
-
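For readers tracing call sites, a minimal usage sketch of the moved helpers. It assumes the functions keep the signatures shown in the removed hunk above once they land in `common.token_utils` (the module path used by every new import in this commit), and that `tiktoken` with its cached `cl100k_base` encoding is available:

```python
from common.token_utils import (
    num_tokens_from_string,
    total_token_count_from_response,
    truncate,
)

# Token counting uses the shared cl100k_base tiktoken encoder;
# encoding failures are swallowed and reported as 0 tokens.
print(num_tokens_from_string("Move token related functions to common"))

# truncate() cuts by token count, not by character count.
print(truncate("a long passage that must fit a fixed token budget", 5))

# total_token_count_from_response() probes several provider response
# shapes; an OpenAI-style dict with a "usage" field is one of them.
fake_resp = {"usage": {"total_tokens": 42}}
print(total_token_count_from_response(fake_resp))  # 42
```

Centralizing these helpers means the callers touched above (the LLM adapters, the tokenizer flow, RAPTOR, the task executor) share one encoder instance instead of each reaching into `rag.utils` for it.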