Fix: 'AzureEmbed' object has no attribute 'total_token_count_from_response' (#11962)

### What problem does this PR solve?

https://github.com/infiniflow/ragflow/issues/11956

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Stephen Hu
2025-12-16 11:29:07 +08:00
committed by GitHub
parent a98887d4ca
commit ef5d1d4b74

View File

@ -49,17 +49,6 @@ class Base(ABC):
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class BuiltinEmbed(Base):
_FACTORY_NAME = "Builtin"
@ -127,7 +116,7 @@ class OpenAIEmbed(Base):
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
try:
return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
return np.array(res.data[0].embedding), total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
@ -216,7 +205,7 @@ class QWenEmbed(Base):
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += self.total_token_count_from_response(resp)
token_count += total_token_count_from_response(resp)
except Exception as _e:
log_exception(_e, resp)
raise
@ -225,7 +214,7 @@ class QWenEmbed(Base):
def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count_from_response(resp)
return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
except Exception as _e:
log_exception(_e, resp)
raise Exception(f"Error: {resp}")
@ -253,7 +242,7 @@ class ZhipuEmbed(Base):
res = self.client.embeddings.create(input=txt, model=self.model_name)
try:
arr.append(res.data[0].embedding)
tks_num += self.total_token_count_from_response(res)
tks_num += total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
@ -262,7 +251,7 @@ class ZhipuEmbed(Base):
def encode_queries(self, text):
res = self.client.embeddings.create(input=text, model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
return np.array(res.data[0].embedding), total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
@ -323,7 +312,7 @@ class XinferenceEmbed(Base):
try:
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count_from_response(res)
total_tokens += total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
@ -333,7 +322,7 @@ class XinferenceEmbed(Base):
res = None
try:
res = self.client.embeddings.create(input=[text], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
return np.array(res.data[0].embedding), total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
@ -409,7 +398,7 @@ class JinaMultiVecEmbed(Base):
ress.append(chunk_emb)
token_count += self.total_token_count_from_response(res)
token_count +=total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, response)
raise Exception(f"Error: {response}")
@ -443,7 +432,7 @@ class MistralEmbed(Base):
try:
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
token_count += self.total_token_count_from_response(res)
token_count += total_token_count_from_response(res)
break
except Exception as _e:
if retry_max == 1:
@ -460,7 +449,7 @@ class MistralEmbed(Base):
while retry_max > 0:
try:
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count_from_response(res)
return np.array(res.data[0].embedding), total_token_count_from_response(res)
except Exception as _e:
if retry_max == 1:
log_exception(_e)
@ -595,7 +584,7 @@ class NvidiaEmbed(Base):
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count_from_response(res)
token_count += total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, response)
raise Exception(f"Error: {response}")
@ -732,7 +721,7 @@ class SILICONFLOWEmbed(Base):
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count_from_response(res)
token_count += total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, response)
raise Exception(f"Error: {response}")
@ -748,7 +737,7 @@ class SILICONFLOWEmbed(Base):
response = requests.post(self.base_url, json=payload, headers=self.headers)
try:
res = response.json()
return np.array(res["data"][0]["embedding"]), self.total_token_count_from_response(res)
return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, response)
raise Exception(f"Error: {response}")
@ -794,7 +783,7 @@ class BaiduYiyanEmbed(Base):
try:
return (
np.array([r["embedding"] for r in res["data"]]),
self.total_token_count_from_response(res),
total_token_count_from_response(res),
)
except Exception as _e:
log_exception(_e, res)
@ -805,7 +794,7 @@ class BaiduYiyanEmbed(Base):
try:
return (
np.array([r["embedding"] for r in res["data"]]),
self.total_token_count_from_response(res),
total_token_count_from_response(res),
)
except Exception as _e:
log_exception(_e, res)