mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor: How LiteLLMBase Calculate total count (#10532)
### What problem does this PR solve? How LiteLLMBase Calculate total count ### Type of change - [x] Refactoring
This commit is contained in:
@ -167,7 +167,7 @@ class Base(ABC):
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
@ -193,7 +193,7 @@ class Base(ABC):
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
||||
|
||||
@ -283,7 +283,7 @@ class Base(ABC):
|
||||
for _ in range(self.max_rounds + 1):
|
||||
logging.info(f"{self.tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
||||
tk_count += self.total_token_count(response)
|
||||
tk_count += total_token_count_from_response(response)
|
||||
if any([not response.choices, not response.choices[0].message]):
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
|
||||
@ -401,7 +401,7 @@ class Base(ABC):
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
@ -437,7 +437,7 @@ class Base(ABC):
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
continue
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
@ -472,9 +472,6 @@ class Base(ABC):
|
||||
|
||||
yield total_tokens
|
||||
|
||||
def total_token_count(self, resp):
|
||||
return total_token_count_from_response(resp)
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
|
||||
@ -604,7 +601,7 @@ class BaiChuanChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
@ -627,7 +624,7 @@ class BaiChuanChat(Base):
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans = resp.choices[0].delta.content
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
@ -691,9 +688,9 @@ class ZhipuChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
tk_count = self.total_token_count(resp)
|
||||
tk_count = total_token_count_from_response(resp)
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
tk_count = self.total_token_count(resp)
|
||||
tk_count = total_token_count_from_response(resp)
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
@ -812,7 +809,7 @@ class MiniMaxChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
@ -847,7 +844,7 @@ class MiniMaxChat(Base):
|
||||
if "choices" in resp and "delta" in resp["choices"][0]:
|
||||
text = resp["choices"][0]["delta"]["content"]
|
||||
ans = text
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(text)
|
||||
else:
|
||||
@ -886,7 +883,7 @@ class MistralChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
@ -1110,7 +1107,7 @@ class BaiduYiyanChat(Base):
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
|
||||
ans = response["result"]
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
|
||||
@ -1124,7 +1121,7 @@ class BaiduYiyanChat(Base):
|
||||
for resp in response:
|
||||
resp = resp.body
|
||||
ans = resp["result"]
|
||||
total_tokens = self.total_token_count(resp)
|
||||
total_tokens = total_token_count_from_response(resp)
|
||||
|
||||
yield ans
|
||||
|
||||
@ -1478,7 +1475,7 @@ class LiteLLMBase(ABC):
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
@ -1512,7 +1509,7 @@ class LiteLLMBase(ABC):
|
||||
reasoning_start = False
|
||||
ans = delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(delta.content)
|
||||
|
||||
@ -1665,7 +1662,7 @@ class LiteLLMBase(ABC):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
tk_count += self.total_token_count(response)
|
||||
tk_count += total_token_count_from_response(response)
|
||||
|
||||
if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
@ -1797,7 +1794,7 @@ class LiteLLMBase(ABC):
|
||||
answer += delta.content
|
||||
yield delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
@ -1846,7 +1843,7 @@ class LiteLLMBase(ABC):
|
||||
delta = resp.choices[0].delta
|
||||
if not hasattr(delta, "content") or delta.content is None:
|
||||
continue
|
||||
tol = self.total_token_count(resp)
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
@ -1880,17 +1877,6 @@ class LiteLLMBase(ABC):
|
||||
|
||||
yield total_tokens
|
||||
|
||||
def total_token_count(self, resp):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user