fix sequence2txt error and usage total token issue (#2961)

### What problem does this PR solve?

#1363

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2024-10-22 11:38:37 +08:00
committed by GitHub
parent 6a4858a7ee
commit b2524eec49
5 changed files with 16 additions and 11 deletions

View File

@ -67,14 +67,16 @@ class Base(ABC):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage") or not resp.usage
else resp.usage.get("total_tokens", total_tokens)
)
total_tokens += 1
if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"

View File

@ -87,7 +87,7 @@ class AzureSeq2txt(Base):
class XinferenceSeq2txt(Base):
def __init__(self,key,model_name="whisper-small",**kwargs):
def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get('base_url', None)
self.model_name = model_name
self.key = key