From 0ecccd27eb00699fbe44c9463dee447c95e3c367 Mon Sep 17 00:00:00 2001 From: Stephen Hu <812791840@qq.com> Date: Fri, 31 Oct 2025 09:46:16 +0800 Subject: [PATCH] Refactor:improve the logic for rerank models to cal the total token count (#10882) ### What problem does this PR solve? improve the logic for rerank models to cal the total token count ### Type of change - [x] Refactoring --- rag/llm/rerank_model.py | 11 ++++------- rag/utils/__init__.py | 6 ++++++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 96113e9b1..7a4207d1e 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -36,9 +36,6 @@ class Base(ABC): def similarity(self, query: str, texts: list): raise NotImplementedError("Please implement encode method!") - def total_token_count(self, resp): - return total_token_count_from_response(resp) - class JinaRerank(Base): _FACTORY_NAME = "Jina" @@ -58,7 +55,7 @@ class JinaRerank(Base): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) - return rank, self.total_token_count(res) + return rank, total_token_count_from_response(res) class XInferenceRerank(Base): @@ -301,7 +298,7 @@ class SILICONFLOWRerank(Base): log_exception(_e, response) return ( rank, - response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"], + total_token_count_from_response(response), ) @@ -330,7 +327,7 @@ class BaiduYiyanRerank(Base): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) - return rank, self.total_token_count(res) + return rank, total_token_count_from_response(res) class VoyageRerank(Base): @@ -378,7 +375,7 @@ class QWenRerank(Base): rank[r.index] = r.relevance_score except Exception as _e: log_exception(_e, resp) - return rank, resp.usage.total_tokens + return rank, total_token_count_from_response(resp) else: raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}") diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index 16acfd98e..19d952d58 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -69,6 +69,12 @@ def total_token_count_from_response(resp): return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"] except Exception: pass + + if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']: + try: + return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"] + except Exception: + pass return 0