Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Refactor: how LiteLLMBase calculates the total token count (#10532)
### What problem does this PR solve?

Refactors how LiteLLMBase (and the other chat-model classes) calculate the total token count: the per-class `total_token_count` methods are removed and every call site now uses the module-level helper `total_token_count_from_response` directly.

### Type of change

- [x] Refactoring
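The definition of `total_token_count_from_response` is not shown in this diff, only its call sites. A minimal sketch consistent with the removed `LiteLLMBase.total_token_count` method (see the final hunk below) would be:

```python
# Hypothetical sketch of the shared helper -- the diff shows only its call
# sites, not its definition. The logic mirrors the removed
# LiteLLMBase.total_token_count method in the last hunk below.
def total_token_count_from_response(resp):
    # SDK-style responses: an object with a .usage attribute.
    try:
        return resp.usage.total_tokens
    except Exception:
        pass
    # Raw JSON responses: a plain dict (e.g. BaiduYiyanChat's resp.body).
    try:
        return resp["usage"]["total_tokens"]
    except Exception:
        pass
    # No usage info in this response (common for streaming chunks);
    # callers fall back to num_tokens_from_string().
    return 0
```

Trying attribute access first and falling back to dict access lets a single helper serve both SDK response objects and raw JSON bodies.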
@@ -167,7 +167,7 @@ class Base(ABC):
         ans = response.choices[0].message.content.strip()
         if response.choices[0].finish_reason == "length":
             ans = self._length_stop(ans)
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -193,7 +193,7 @@ class Base(ABC):
                     reasoning_start = False
                 ans = resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(resp.choices[0].delta.content)

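The same call-site pattern recurs in every streaming method below: most providers attach `usage` only to the final chunk, if at all, so when the helper returns 0 the token count is estimated from the chunk text instead. A condensed sketch of the pattern (names come from the diff; the surrounding loop is illustrative, and the `else` branch is truncated in the hunks, so keeping the provider-reported total there is an assumption):

```python
# Condensed sketch of the recurring streaming call-site pattern; assumes
# total_token_count_from_response as sketched above and RAGFlow's
# num_tokens_from_string estimator. Not a verbatim excerpt.
def count_stream_tokens(stream):
    total_tokens = 0
    for resp in stream:
        text = resp.choices[0].delta.content or ""
        tol = total_token_count_from_response(resp)  # 0 when the chunk carries no usage
        if not tol:
            # No usage payload on this chunk: estimate from the emitted text.
            total_tokens += num_tokens_from_string(text)
        else:
            # Assumed: keep the provider-reported total (the else-branch body
            # is cut off in the hunks shown here).
            total_tokens = tol
    return total_tokens
```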
@@ -283,7 +283,7 @@ class Base(ABC):
         for _ in range(self.max_rounds + 1):
             logging.info(f"{self.tools=}")
             response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)
             if any([not response.choices, not response.choices[0].message]):
                 raise Exception(f"500 response structure error. Response: {response}")

@@ -401,7 +401,7 @@ class Base(ABC):
                     answer += resp.choices[0].delta.content
                     yield resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -437,7 +437,7 @@ class Base(ABC):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -472,9 +472,6 @@ class Base(ABC):

         yield total_tokens

-    def total_token_count(self, resp):
-        return total_token_count_from_response(resp)
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

@@ -604,7 +601,7 @@ class BaiChuanChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
@@ -627,7 +624,7 @@ class BaiChuanChat(Base):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans = resp.choices[0].delta.content
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -691,9 +688,9 @@ class ZhipuChat(Base):
                         ans += LENGTH_NOTIFICATION_CN
                     else:
                         ans += LENGTH_NOTIFICATION_EN
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 if resp.choices[0].finish_reason == "stop":
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -812,7 +809,7 @@ class MiniMaxChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf):
         if system and history and history[0].get("role") != "system":
@@ -847,7 +844,7 @@ class MiniMaxChat(Base):
                     if "choices" in resp and "delta" in resp["choices"][0]:
                         text = resp["choices"][0]["delta"]["content"]
                     ans = text
-                    tol = self.total_token_count(resp)
+                    tol = total_token_count_from_response(resp)
                     if not tol:
                         total_tokens += num_tokens_from_string(text)
                     else:
@@ -886,7 +883,7 @@ class MistralChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
@@ -1110,7 +1107,7 @@ class BaiduYiyanChat(Base):
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
         ans = response["result"]
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
@@ -1124,7 +1121,7 @@ class BaiduYiyanChat(Base):
             for resp in response:
                 resp = resp.body
                 ans = resp["result"]
-                total_tokens = self.total_token_count(resp)
+                total_tokens = total_token_count_from_response(resp)

                 yield ans

@@ -1478,7 +1475,7 @@ class LiteLLMBase(ABC):
         if response.choices[0].finish_reason == "length":
             ans = self._length_stop(ans)

-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -1512,7 +1509,7 @@ class LiteLLMBase(ABC):
                     reasoning_start = False
                 ans = delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(delta.content)

@@ -1665,7 +1662,7 @@ class LiteLLMBase(ABC):
                 timeout=self.timeout,
             )

-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)

             if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
                 raise Exception(f"500 response structure error. Response: {response}")
@@ -1797,7 +1794,7 @@ class LiteLLMBase(ABC):
                     answer += delta.content
                     yield delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:
@@ -1846,7 +1843,7 @@ class LiteLLMBase(ABC):
                 delta = resp.choices[0].delta
                 if not hasattr(delta, "content") or delta.content is None:
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:
@@ -1880,17 +1877,6 @@ class LiteLLMBase(ABC):

         yield total_tokens

-    def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

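Net effect: both `Base.total_token_count` (already a one-line delegation) and this duplicate implementation in `LiteLLMBase` are gone, leaving a single function that understands both usage payload shapes. A quick sanity check against the sketch from the PR description (assuming that implementation; not part of this PR):

```python
# Hypothetical usage check for the shared helper, covering the two response
# shapes seen in this diff: SDK objects (Base/LiteLLMBase) and plain dicts
# (BaiduYiyanChat's resp.body).
from types import SimpleNamespace

obj_resp = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
dict_resp = {"usage": {"total_tokens": 42}}
empty_chunk = SimpleNamespace(usage=None)  # e.g. a streaming chunk without usage

assert total_token_count_from_response(obj_resp) == 42
assert total_token_count_from_response(dict_resp) == 42
assert total_token_count_from_response(empty_chunk) == 0  # caller falls back to estimation
```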