Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu
2024-12-08 14:21:12 +08:00
committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
97 changed files with 2558 additions and 1976 deletions

View File

@ -133,7 +133,8 @@ class Canvas(ABC):
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
if k in ["components"]:
continue
dsl[k] = deepcopy(self.dsl[k])
for k, cpn in self.components.items():
@ -158,7 +159,8 @@ class Canvas(ABC):
def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
@ -173,7 +175,8 @@ class Canvas(ABC):
if kwargs.get("stream"):
for an in ans():
yield an
else: yield ans
else:
yield ans
return
if not self.path:
@ -188,7 +191,8 @@ class Canvas(ABC):
def prepare2run(cpns):
nonlocal ran, ans
for c in cpns:
if self.path[-1] and c == self.path[-1][-1]: continue
if self.path[-1] and c == self.path[-1][-1]:
continue
cpn = self.components[c]["obj"]
if cpn.component_name == "Answer":
self.answer.append(c)
@ -197,7 +201,8 @@ class Canvas(ABC):
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
if c not in waiting:
waiting.append(c)
continue
yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs)
@ -211,10 +216,12 @@ class Canvas(ABC):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break
if not cpn["downstream"]:
break
loop = self._find_loop()
if loop: raise OverflowError(f"Too much loops: {loop}")
if loop:
raise OverflowError(f"Too much loops: {loop}")
if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
@ -283,19 +290,22 @@ class Canvas(ABC):
def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2: return False
if len(path) < 2:
return False
for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break
if len(path) < 2: return False
if len(path) < 2:
return False
for l in range(2, len(path) // 2):
pat = ",".join(path[0:l])
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
@ -303,7 +313,7 @@ class Canvas(ABC):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:l]])
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat
return False

View File

@ -39,3 +39,73 @@ def component_class(class_name):
m = importlib.import_module("agent.component")
c = getattr(m, class_name)
return c
__all__ = [
"Begin",
"BeginParam",
"Generate",
"GenerateParam",
"Retrieval",
"RetrievalParam",
"Answer",
"AnswerParam",
"Categorize",
"CategorizeParam",
"Switch",
"SwitchParam",
"Relevant",
"RelevantParam",
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
"ConcentratorParam",
"Baidu",
"BaiduParam",
"DuckDuckGo",
"DuckDuckGoParam",
"Wikipedia",
"WikipediaParam",
"PubMed",
"PubMedParam",
"ArXiv",
"ArXivParam",
"Google",
"GoogleParam",
"Bing",
"BingParam",
"GoogleScholar",
"GoogleScholarParam",
"DeepL",
"DeepLParam",
"GitHub",
"GitHubParam",
"BaiduFanyi",
"BaiduFanyiParam",
"QWeather",
"QWeatherParam",
"ExeSQL",
"ExeSQLParam",
"YahooFinance",
"YahooFinanceParam",
"WenCai",
"WenCaiParam",
"Jin10",
"Jin10Param",
"TuShare",
"TuShareParam",
"AkShare",
"AkShareParam",
"Crawler",
"CrawlerParam",
"Invoke",
"InvokeParam",
"Template",
"TemplateParam",
"Email",
"EmailParam",
"component_class"
]

View File

@ -428,7 +428,8 @@ class ComponentBase(ABC):
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
if not isinstance(o, list):
o = [o]
o = pd.DataFrame(o)
if allow_partial or not isinstance(o, partial):
@ -440,7 +441,8 @@ class ComponentBase(ABC):
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
else: outs = oo
else:
outs = oo
return self._param.output_var_name, outs
def reset(self):
@ -482,13 +484,15 @@ class ComponentBase(ABC):
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
return df
upstream_outs = []
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.get_component_name(u) in ["switch", "concentrator"]:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
@ -532,7 +536,8 @@ class ComponentBase(ABC):
reversed_cpnts.extend(self._canvas.path[-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]: continue
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]
@staticmethod

View File

@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
super().check()
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k: raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
if not k:
raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
def get_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for l in desc.get("examples", "").split("\n"):
if not l: continue
cate_lines.append("Question: {}\tCategory: {}".format(l, c))
for line in desc.get("examples", "").split("\n"):
if not line:
continue
cate_lines.append("Question: {}\tCategory: {}".format(line, c))
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
from abc import ABC
import re
from agent.component.base import ComponentBase, ComponentParamBase
import deepl

View File

@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
if self.host == "ragflow-mysql":
raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow":
raise ValueError("The host is not accessible.")
class ExeSQL(ComponentBase, ABC):

View File

@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):
def gen_conf(self):
conf = {}
if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
if self.temperature > 0: conf["temperature"] = self.temperature
if self.top_p > 0: conf["top_p"] = self.top_p
if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf
@ -83,7 +88,8 @@ class Generate(ComponentBase):
recall_docs = []
for i in idx:
did = retrieval_res.loc[int(i), "doc_id"]
if did in doc_ids: continue
if did in doc_ids:
continue
doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
@ -108,7 +114,8 @@ class Generate(ComponentBase):
retrieval_res = []
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
@ -142,7 +149,8 @@ class Generate(ComponentBase):
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@ -164,9 +172,11 @@ class Generate(ComponentBase):
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@ -185,9 +195,11 @@ class Generate(ComponentBase):
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}

View File

@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
hist = self._canvas.get_history(4)
conv = []
for m in hist:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)

View File

@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
def check(self):
self.check_empty(self.conditions, "[Switch] conditions")
for cond in self.conditions:
if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
if not cond["to"]:
raise ValueError("[Switch] 'To' can not be empty!")
class Switch(ComponentBase, ABC):
@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if not item["cpn_id"]:
continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
if not item["cpn_id"]:continue
if not item["cpn_id"]:
continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
elif operator == ">":
try:
return True if float(input) > float(value) else False
except Exception as e:
except Exception:
return True if input > value else False
elif operator == "<":
try:
return True if float(input) < float(value) else False
except Exception as e:
except Exception:
return True if input < value else False
elif operator == "":
try:
return True if float(input) >= float(value) else False
except Exception as e:
except Exception:
return True if input >= value else False
elif operator == "":
try:
return True if float(input) <= float(value) else False
except Exception as e:
except Exception:
return True if input <= value else False
raise ValueError('Not supported operator' + operator)

View File

@ -47,7 +47,8 @@ class Template(ComponentBase):
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")

View File

@ -43,6 +43,7 @@ if __name__ == '__main__':
else:
print(ans["content"])
if DEBUG: print(canvas.path)
if DEBUG:
print(canvas.path)
question = input("\n==================== User =====================\n> ")
canvas.add_user_input(question)