Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu
2024-12-08 14:21:12 +08:00
committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
97 changed files with 2558 additions and 1976 deletions

View File

@ -69,7 +69,8 @@ class Base(ABC):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@ -81,7 +82,8 @@ class Base(ABC):
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
else:
total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -98,13 +100,15 @@ class Base(ABC):
class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)
class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url = "https://api.moonshot.cn/v1"
if not base_url:
base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)
@ -128,7 +132,8 @@ class HuggingFaceChat(Base):
class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url = "https://api.deepseek.com/v1"
if not base_url:
base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)
@ -202,7 +207,8 @@ class BaiChuanChat(Base):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@ -313,8 +319,10 @@ class ZhipuChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
try:
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@ -333,8 +341,10 @@ class ZhipuChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
tk_count = 0
try:
@ -345,7 +355,8 @@ class ZhipuChat(Base):
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@ -354,7 +365,8 @@ class ZhipuChat(Base):
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -372,11 +384,16 @@ class OllamaChat(Base):
history.insert(0, {"role": "system", "content": system})
try:
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@ -392,11 +409,16 @@ class OllamaChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@ -636,7 +658,8 @@ class MistralChat(Base):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content: continue
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@ -1196,7 +1219,8 @@ class SparkChat(Base):
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
if model_name in model2version:
model_version = model2version[model_name]
else: model_version = model_name
else:
model_version = model_name
super().__init__(key, model_version, base_url)
@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
try:
@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
total_tokens = 0