mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor for total_tokens. (#4652)
### What problem does this PR solve? #4567 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -42,6 +42,17 @@ class Base(ABC):
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def total_token_count(self, resp):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
class DefaultRerank(Base):
|
||||
_model = None
|
||||
@ -115,7 +126,7 @@ class JinaRerank(Base):
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
return rank, res["usage"]["total_tokens"]
|
||||
return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
|
||||
@ -417,7 +428,7 @@ class BaiduYiyanRerank(Base):
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
return rank, res["usage"]["total_tokens"]
|
||||
return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class VoyageRerank(Base):
|
||||
|
||||
Reference in New Issue
Block a user