Refactor:improve the logic for rerank models to cal the total token count (#10882)

### What problem does this PR solve?

improve the logic for rerank models to cal the total token count

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2025-10-31 09:46:16 +08:00
committed by GitHub
parent 5a830ea68b
commit 0ecccd27eb
2 changed files with 10 additions and 7 deletions

View File

@ -36,9 +36,6 @@ class Base(ABC):
def similarity(self, query: str, texts: list):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
return total_token_count_from_response(resp)
class JinaRerank(Base):
_FACTORY_NAME = "Jina"
@ -58,7 +55,7 @@ class JinaRerank(Base):
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
return rank, self.total_token_count(res)
return rank, total_token_count_from_response(res)
class XInferenceRerank(Base):
@ -301,7 +298,7 @@ class SILICONFLOWRerank(Base):
log_exception(_e, response)
return (
rank,
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
total_token_count_from_response(response),
)
@ -330,7 +327,7 @@ class BaiduYiyanRerank(Base):
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
return rank, self.total_token_count(res)
return rank, total_token_count_from_response(res)
class VoyageRerank(Base):
@ -378,7 +375,7 @@ class QWenRerank(Base):
rank[r.index] = r.relevance_score
except Exception as _e:
log_exception(_e, resp)
return rank, resp.usage.total_tokens
return rank, total_token_count_from_response(resp)
else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")

View File

@ -69,6 +69,12 @@ def total_token_count_from_response(resp):
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
except Exception:
pass
if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
try:
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
except Exception:
pass
return 0