Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff. In the hunks shown below, the recurring fix splits compound statements that shared a line with their `if`/`else` clause onto separate lines (see the sketch that follows).
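The recurring finding here corresponds to pycodestyle's "multiple statements on one line (colon)" rule, E701 in Ruff's rule set; exactly which rules this repository enables is an assumption, since its Ruff configuration is not shown in this diff. A minimal before/after sketch of the pattern, using purely illustrative names rather than code from this repository:

```python
# Before: flagged by Ruff because the body shares a line with the colon of
# its `if` clause (E701-style "multiple statements on one line").
def clamp_before(value):
    if value is None: return 0
    return value

# After: the body moves to its own indented line, which is the mechanical
# change applied throughout this PR.
def clamp_after(value):
    if value is None:
        return 0
    return value
```

Running `ruff check .` from the repository root reports findings of this kind; `clamp_before`/`clamp_after` are hypothetical helpers used only for illustration.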

### Type of change

- [x] Refactoring
Zhichang Yu
2024-12-08 14:21:12 +08:00
committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
97 changed files with 2558 additions and 1976 deletions

@@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):
     def gen_conf(self):
         conf = {}
-        if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
-        if self.temperature > 0: conf["temperature"] = self.temperature
-        if self.top_p > 0: conf["top_p"] = self.top_p
-        if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
-        if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
+        if self.max_tokens > 0:
+            conf["max_tokens"] = self.max_tokens
+        if self.temperature > 0:
+            conf["temperature"] = self.temperature
+        if self.top_p > 0:
+            conf["top_p"] = self.top_p
+        if self.presence_penalty > 0:
+            conf["presence_penalty"] = self.presence_penalty
+        if self.frequency_penalty > 0:
+            conf["frequency_penalty"] = self.frequency_penalty
         return conf
@@ -83,7 +88,8 @@ class Generate(ComponentBase):
         recall_docs = []
         for i in idx:
             did = retrieval_res.loc[int(i), "doc_id"]
-            if did in doc_ids: continue
+            if did in doc_ids:
+                continue
             doc_ids.add(did)
             recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
@@ -108,7 +114,8 @@ class Generate(ComponentBase):
         retrieval_res = []
         self._param.inputs = []
         for para in self._param.parameters:
-            if not para.get("component_id"): continue
+            if not para.get("component_id"):
+                continue
             component_id = para["component_id"].split("@")[0]
             if para["component_id"].lower().find("@") >= 0:
                 cpn_id, key = para["component_id"].split("@")
@@ -142,7 +149,8 @@ class Generate(ComponentBase):
         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else: retrieval_res = pd.DataFrame([])
+        else:
+            retrieval_res = pd.DataFrame([])
         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@@ -164,9 +172,11 @@ class Generate(ComponentBase):
             return pd.DataFrame([res])
         msg = self._canvas.get_history(self._param.message_history_window_size)
-        if len(msg) < 1: msg.append({"role": "user", "content": ""})
+        if len(msg) < 1:
+            msg.append({"role": "user", "content": ""})
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-        if len(msg) < 2: msg.append({"role": "user", "content": ""})
+        if len(msg) < 2:
+            msg.append({"role": "user", "content": ""})
         ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@@ -185,9 +195,11 @@ class Generate(ComponentBase):
             return
         msg = self._canvas.get_history(self._param.message_history_window_size)
-        if len(msg) < 1: msg.append({"role": "user", "content": ""})
+        if len(msg) < 1:
+            msg.append({"role": "user", "content": ""})
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
-        if len(msg) < 2: msg.append({"role": "user", "content": ""})
+        if len(msg) < 2:
+            msg.append({"role": "user", "content": ""})
         answer = ""
         for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}