Refactor embedding batch_size (#3825)

### What problem does this PR solve?

Refactor embedding batch size handling: `LLMBundle.encode` no longer accepts a `batch_size` argument; each embedding model's `encode` implementation now batches its inputs internally. Close #3657
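
For context, the diff below only shows the caller side; the model-side counterpart is in the other changed files, which are not included in this excerpt. A minimal sketch of the pattern the PR moves to follows, assuming the model-side `encode` loops over fixed-size slices. `BaseEmbed`, `_embed_batch`, and `DEFAULT_BATCH` are hypothetical names for illustration, not the actual ragflow API.

```python
import numpy as np

DEFAULT_BATCH = 16  # assumed per-model default; the real value varies by provider


class BaseEmbed:
    """Illustrative base class; not the actual ragflow class hierarchy."""

    def _embed_batch(self, texts: list) -> tuple:
        """Embed a single batch; returns (vectors, used_tokens). Stub for illustration."""
        raise NotImplementedError

    def encode(self, texts: list):
        # Batching now lives inside the model instead of in LLMBundle.encode.
        vecs, used_tokens = [], 0
        for i in range(0, len(texts), DEFAULT_BATCH):
            batch_vecs, batch_tokens = self._embed_batch(texts[i:i + DEFAULT_BATCH])
            vecs.append(batch_vecs)
            used_tokens += batch_tokens
        return np.concatenate(vecs, axis=0), used_tokens
```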

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
Author: Zhichang Yu
Date: 2024-12-03 16:22:39 +08:00
Committed by: GitHub
Parent: 934dbc2e2b
Commit: 92ab7ef659
3 changed files with 160 additions and 109 deletions


```diff
@@ -232,13 +232,13 @@ class LLMBundle(object):
             self.max_length = lm.max_tokens
             break
 
-    def encode(self, texts: list, batch_size=32):
-        emd, used_tokens = self.mdl.encode(texts, batch_size)
+    def encode(self, texts: list):
+        embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             logging.error(
                 "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
-        return emd, used_tokens
+        return embeddings, used_tokens
 
     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
@@ -280,7 +280,7 @@ class LLMBundle(object):
                 logging.error(
                     "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
-                yield chunk
+            yield chunk
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
```
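
After this change, callers simply drop the `batch_size` argument. A hedged usage sketch, with the `LLMBundle` construction arguments assumed rather than taken from this excerpt:

```python
# Construction arguments are assumed for illustration.
bundle = LLMBundle(tenant_id, LLMType.EMBEDDING)
vectors, tokens = bundle.encode(["hello", "world"])  # batching handled by the model
```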