diff --git a/common/token_utils.py b/common/token_utils.py index 4d5220fc9..5763dc97e 100644 --- a/common/token_utils.py +++ b/common/token_utils.py @@ -56,6 +56,12 @@ def total_token_count_from_response(resp): except Exception: pass + if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"): + try: + return resp.meta.billed_units.input_tokens + except Exception: + pass + if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']: try: return resp["usage"]["total_tokens"] diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 344d3aac1..3d1ca5546 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -639,7 +639,7 @@ class CoHereEmbed(Base): ) try: ress.extend([d for d in res.embeddings.float]) - token_count += res.meta.billed_units.input_tokens + token_count += total_token_count_from_response(res) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}") @@ -653,7 +653,7 @@ class CoHereEmbed(Base): embedding_types=["float"], ) try: - return np.array(res.embeddings.float[0]), int(res.meta.billed_units.input_tokens) + return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res)) except Exception as _e: log_exception(_e, res) raise Exception(f"Error: {res}")