Refactor for total_tokens. (#4652)

### What problem does this PR solve?

#4567
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-01-26 13:54:26 +08:00
committed by GitHub
parent c24137bd11
commit 4776fa5e4e
3 changed files with 79 additions and 52 deletions

View File

@ -42,6 +42,17 @@ class Base(ABC):
def similarity(self, query: str, texts: list):
    """Score how relevant each entry of ``texts`` is to ``query``.

    This base implementation is only a placeholder: it always raises, so
    every concrete reranker has to provide its own ``similarity``.

    Args:
        query: The query string the texts are ranked against.
        texts: The candidate texts to score.

    Raises:
        NotImplementedError: Always, because the base class has no model.
    """
    # The message names the method that is actually missing (``similarity``),
    # so the error points implementers at the right override.
    raise NotImplementedError("Please implement similarity method!")
def total_token_count(self, resp):
    """Return the total token usage reported by ``resp``, or 0 if absent.

    Two response shapes are probed, in this order:

    1. attribute style: ``resp.usage.total_tokens``
    2. mapping style:   ``resp["usage"]["total_tokens"]``

    The first lookup that succeeds wins and its value is returned as-is
    (no type coercion). Any exception raised by a lookup is swallowed and
    the next shape is tried; when neither works the result is ``0``.
    """
    lookups = (
        lambda r: r.usage.total_tokens,
        lambda r: r["usage"]["total_tokens"],
    )
    for lookup in lookups:
        try:
            return lookup(resp)
        except Exception:
            # Best-effort probe: this shape does not match, try the next.
            continue
    return 0
class DefaultRerank(Base):
_model = None
@ -115,7 +126,7 @@ class JinaRerank(Base):
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]
return rank, self.total_token_count(res)
class YoudaoRerank(DefaultRerank):
@ -417,7 +428,7 @@ class BaiduYiyanRerank(Base):
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]
return rank, self.total_token_count(res)
class VoyageRerank(Base):