diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index ded4f7f3f..c368807f9 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -159,6 +159,12 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): + try: + if not DB.is_connection_usable(): + DB.connect() + except Exception: + DB.close() + DB.connect() e, tenant = TenantService.get_by_id(tenant_id) if not e: logging.error(f"Tenant not found: {tenant_id}") @@ -356,21 +362,22 @@ class LLMBundle: ans = "" chat_streamly = self.mdl.chat_streamly - + total_tokens = 0 if self.is_tools and self.mdl.is_tools: chat_streamly = self.mdl.chat_streamly_with_tools for txt in chat_streamly(system, history, gen_conf): if isinstance(txt, int): + total_tokens = txt if self.langfuse: generation.end(output={"output": ans}) - - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) - return ans + break if txt.endswith(""): ans = ans.rstrip("") ans += txt yield ans + if total_tokens > 0: + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): + logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))