diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 61a09d0df..8ca38c893 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -167,7 +167,7 @@ class Base(ABC):
         ans = response.choices[0].message.content.strip()
         if response.choices[0].finish_reason == "length":
             ans = self._length_stop(ans)
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -193,7 +193,7 @@ class Base(ABC):
                 reasoning_start = False
                 ans = resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(resp.choices[0].delta.content)

@@ -283,7 +283,7 @@ class Base(ABC):
         for _ in range(self.max_rounds + 1):
             logging.info(f"{self.tools=}")
             response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)

             if any([not response.choices, not response.choices[0].message]):
                 raise Exception(f"500 response structure error. Response: {response}")
@@ -401,7 +401,7 @@ class Base(ABC):
                 answer += resp.choices[0].delta.content
                 yield resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -437,7 +437,7 @@ class Base(ABC):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -472,9 +472,6 @@ class Base(ABC):

         yield total_tokens

-    def total_token_count(self, resp):
-        return total_token_count_from_response(resp)
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

@@ -604,7 +601,7 @@ class BaiChuanChat(Base):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
@@ -627,7 +624,7 @@ class BaiChuanChat(Base):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans = resp.choices[0].delta.content
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
@@ -691,9 +688,9 @@ class ZhipuChat(Base):
                         ans += LENGTH_NOTIFICATION_CN
                     else:
                         ans += LENGTH_NOTIFICATION_EN
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 if resp.choices[0].finish_reason == "stop":
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -812,7 +809,7 @@ class MiniMaxChat(Base):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf):
         if system and history and history[0].get("role") != "system":
@@ -847,7 +844,7 @@ class MiniMaxChat(Base):
                 if "choices" in resp and "delta" in resp["choices"][0]:
                     text = resp["choices"][0]["delta"]["content"]
                 ans = text
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(text)
                 else:
@@ -886,7 +883,7 @@ class MistralChat(Base):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
@@ -1110,7 +1107,7 @@ class BaiduYiyanChat(Base):
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
         ans = response["result"]
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
@@ -1124,7 +1121,7 @@ class BaiduYiyanChat(Base):
             for resp in response:
                 resp = resp.body
                 ans = resp["result"]
-                total_tokens = self.total_token_count(resp)
+                total_tokens = total_token_count_from_response(resp)
                 yield ans


@@ -1478,7 +1475,7 @@ class LiteLLMBase(ABC):

         if response.choices[0].finish_reason == "length":
             ans = self._length_stop(ans)
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -1512,7 +1509,7 @@ class LiteLLMBase(ABC):
                 reasoning_start = False
                 ans = delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(delta.content)

@@ -1665,7 +1662,7 @@ class LiteLLMBase(ABC):
                 timeout=self.timeout,
             )

-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)

             if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
                 raise Exception(f"500 response structure error. Response: {response}")
@@ -1797,7 +1794,7 @@ class LiteLLMBase(ABC):
                 answer += delta.content
                 yield delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:
@@ -1846,7 +1843,7 @@ class LiteLLMBase(ABC):
                 delta = resp.choices[0].delta
                 if not hasattr(delta, "content") or delta.content is None:
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:
@@ -1880,17 +1877,6 @@ class LiteLLMBase(ABC):

         yield total_tokens

-    def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""