Refactor: How LiteLLMBase Calculate total count (#10532)

### What problem does this PR solve?

How LiteLLMBase Calculate total count

### Type of change

- [x] Refactoring
This commit is contained in:
Stephen Hu
2025-10-22 12:25:31 +08:00
committed by GitHub
parent a82e9b3d91
commit b30f0be858

View File

@ -167,7 +167,7 @@ class Base(ABC):
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@ -193,7 +193,7 @@ class Base(ABC):
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
@ -283,7 +283,7 @@ class Base(ABC):
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += self.total_token_count(response)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
@ -401,7 +401,7 @@ class Base(ABC):
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@ -437,7 +437,7 @@ class Base(ABC):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@ -472,9 +472,6 @@ class Base(ABC):
yield total_tokens
def total_token_count(self, resp):
return total_token_count_from_response(resp)
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
@ -604,7 +601,7 @@ class BaiChuanChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
@ -627,7 +624,7 @@ class BaiChuanChat(Base):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@ -691,9 +688,9 @@ class ZhipuChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = self.total_token_count(resp)
tk_count = total_token_count_from_response(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = self.total_token_count(resp)
tk_count = total_token_count_from_response(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -812,7 +809,7 @@ class MiniMaxChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf):
if system and history and history[0].get("role") != "system":
@ -847,7 +844,7 @@ class MiniMaxChat(Base):
if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"]
ans = text
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(text)
else:
@ -886,7 +883,7 @@ class MistralChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
@ -1110,7 +1107,7 @@ class BaiduYiyanChat(Base):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
ans = response["result"]
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
@ -1124,7 +1121,7 @@ class BaiduYiyanChat(Base):
for resp in response:
resp = resp.body
ans = resp["result"]
total_tokens = self.total_token_count(resp)
total_tokens = total_token_count_from_response(resp)
yield ans
@ -1478,7 +1475,7 @@ class LiteLLMBase(ABC):
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@ -1512,7 +1509,7 @@ class LiteLLMBase(ABC):
reasoning_start = False
ans = delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
@ -1665,7 +1662,7 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
tk_count += self.total_token_count(response)
tk_count += total_token_count_from_response(response)
if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
raise Exception(f"500 response structure error. Response: {response}")
@ -1797,7 +1794,7 @@ class LiteLLMBase(ABC):
answer += delta.content
yield delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
@ -1846,7 +1843,7 @@ class LiteLLMBase(ABC):
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
@ -1880,17 +1877,6 @@ class LiteLLMBase(ABC):
yield total_tokens
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""