From b3ebc66b131a29b027016bd390954f6c56997c14 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Mon, 8 Jul 2024 09:32:44 +0800 Subject: [PATCH] be more specific for error message (#1409) ### What problem does this PR solve? #918 ### Type of change - [x] Refactoring --- api/apps/canvas_app.py | 8 ++-- graph/canvas.py | 46 ++++++++++++++++-- graph/component/base.py | 21 +++++++-- graph/component/categorize.py | 5 +- graph/component/generate.py | 89 ++++++++++++++++++----------------- graph/component/message.py | 2 +- graph/component/relevant.py | 2 + graph/component/retrieval.py | 8 ++-- graph/component/switch.py | 6 ++- 9 files changed, 126 insertions(+), 61 deletions(-) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 31ac4e753..1f30ab06d 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -95,14 +95,16 @@ def run(): final_ans = {"reference": [], "content": ""} try: canvas = Canvas(cvs.dsl, current_user.id) - print(canvas) if "message" in req: canvas.messages.append({"role": "user", "content": req["message"]}) canvas.add_user_input(req["message"]) answer = canvas.run(stream=stream) + print(canvas) except Exception as e: return server_error_response(e) + assert answer, "Nothing. Is it over?" + if stream: assert isinstance(answer, partial) @@ -116,7 +118,7 @@ def run(): yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) - if "reference" in final_ans: + if final_ans.get("reference"): canvas.reference.append(final_ans["reference"]) cvs.dsl = json.loads(str(canvas)) UserCanvasService.update_by_id(req["id"], cvs.to_dict()) @@ -134,7 +136,7 @@ def run(): return resp canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) - if "reference" in final_ans: + if final_ans.get("reference"): canvas.reference.append(final_ans["reference"]) cvs.dsl = json.loads(str(canvas)) UserCanvasService.update_by_id(req["id"], cvs.to_dict()) diff --git a/graph/canvas.py b/graph/canvas.py index 34c1332d7..7309680b9 100644 --- a/graph/canvas.py +++ b/graph/canvas.py @@ -121,7 +121,6 @@ class Canvas(ABC): if desc["to"] not in cpn["downstream"]: cpn["downstream"].append(desc["to"]) - self.path = self.dsl["path"] self.history = self.dsl["history"] self.messages = self.dsl["messages"] @@ -136,9 +135,21 @@ class Canvas(ABC): self.dsl["answer"] = self.answer self.dsl["reference"] = self.reference self.dsl["embed_id"] = self._embed_id - dsl = deepcopy(self.dsl) + dsl = { + "components": {} + } + for k in self.dsl.keys(): + if k in ["components"]:continue + dsl[k] = deepcopy(self.dsl[k]) + for k, cpn in self.components.items(): - dsl["components"][k]["obj"] = json.loads(str(cpn["obj"])) + if k not in dsl["components"]: + dsl["components"][k] = {} + for c in cpn.keys(): + if c == "obj": + dsl["components"][k][c] = json.loads(str(cpn["obj"])) + continue + dsl["components"][k][c] = deepcopy(cpn[c]) return json.dumps(dsl, ensure_ascii=False) def reset(self): @@ -161,6 +172,9 @@ class Canvas(ABC): except Exception as e: ans = ComponentBase.be_output(str(e)) self.path[-1].append(cpn_id) + if kwargs.get("stream"): + assert isinstance(ans, partial) + return ans self.history.append(("assistant", ans.to_dict("records"))) return ans @@ -190,6 +204,8 @@ class Canvas(ABC): cpn = self.get_component(cpn_id) if not cpn["downstream"]: break + if self._find_loop(): raise OverflowError("Too much loops!") + if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]: switch_out = cpn["obj"].output()[1].iloc[0, 0] assert switch_out in self.components, \ @@ -249,3 +265,27 @@ class Canvas(ABC): def get_embedding_model(self): return self._embed_id + + def _find_loop(self, max_loops=2): + path = self.path[-1][::-1] + if len(path) < 2: return False + + for i in range(len(path)): + if path[i].lower().find("answer") >= 0: + path = path[:i] + break + + if len(path) < 2: return False + + for l in range(1, len(path) // 2): + pat = ",".join(path[0:l]) + path_str = ",".join(path) + if len(pat) >= len(path_str): return False + path_str = path_str[len(pat):] + loop = max_loops + while path_str.find(pat) >= 0 and loop >= 0: + loop -= 1 + path_str = path_str[len(pat):] + if loop < 0: return True + + return False diff --git a/graph/component/base.py b/graph/component/base.py index f758f0192..13364fde7 100644 --- a/graph/component/base.py +++ b/graph/component/base.py @@ -19,7 +19,7 @@ import json import os from copy import deepcopy from functools import partial -from typing import List, Dict +from typing import List, Dict, Tuple, Union import pandas as pd @@ -246,7 +246,7 @@ class ComponentParamBase(ABC): def check_empty(param, descr): if not param: raise ValueError( - descr + " {} not supported empty value." + descr + " does not support empty value." ) @staticmethod @@ -411,12 +411,23 @@ class ComponentBase(ABC): def _run(self, history, **kwargs): raise NotImplementedError() - def output(self) -> pd.DataFrame: + def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]: o = getattr(self._param, self._param.output_var_name) if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): if not isinstance(o, list): o = [o] o = pd.DataFrame(o) - return self._param.output_var_name, o + + if allow_partial or not isinstance(o, partial): + if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): + return pd.DataFrame(o if isinstance(o, list) else [o]) + return self._param.output_var_name, o + + outs = None + for oo in o(): + if not isinstance(oo, pd.DataFrame): + outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]) + else: outs = oo + return self._param.output_var_name, outs def reset(self): setattr(self._param, self._param.output_var_name, None) @@ -446,7 +457,7 @@ class ComponentBase(ABC): if self.component_name.lower().find("answer") >= 0: if self.get_component_name(u) in ["relevant"]: continue - upstream_outs.append(self._canvas.get_component(u)["obj"].output()[1]) + else: upstream_outs.append(self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]) break return pd.concat(upstream_outs, ignore_index=False) diff --git a/graph/component/categorize.py b/graph/component/categorize.py index 8a7f90e70..a5c530e94 100644 --- a/graph/component/categorize.py +++ b/graph/component/categorize.py @@ -35,7 +35,10 @@ class CategorizeParam(GenerateParam): def check(self): super().check() - self.check_empty(self.category_description, "Category examples") + self.check_empty(self.category_description, "[Categorize] Category examples") + for k, v in self.category_description.items(): + if not k: raise ValueError(f"[Categorize] Category name can not be empty!") + if not v["to"]: raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!") def get_prompt(self): cate_lines = [] diff --git a/graph/component/generate.py b/graph/component/generate.py index b49122758..81d4d36bf 100644 --- a/graph/component/generate.py +++ b/graph/component/generate.py @@ -33,31 +33,31 @@ class GenerateParam(ComponentParamBase): super().__init__() self.llm_id = "" self.prompt = "" - self.max_tokens = 256 - self.temperature = 0.1 - self.top_p = 0.3 - self.presence_penalty = 0.4 - self.frequency_penalty = 0.7 + self.max_tokens = 0 + self.temperature = 0 + self.top_p = 0 + self.presence_penalty = 0 + self.frequency_penalty = 0 self.cite = True - #self.parameters = [] + self.parameters = [] def check(self): - self.check_decimal_float(self.temperature, "Temperature") - self.check_decimal_float(self.presence_penalty, "Presence penalty") - self.check_decimal_float(self.frequency_penalty, "Frequency penalty") - self.check_positive_number(self.max_tokens, "Max tokens") - self.check_decimal_float(self.top_p, "Top P") - self.check_empty(self.llm_id, "LLM") - #self.check_defined_type(self.parameters, "Parameters", ["list"]) + self.check_decimal_float(self.temperature, "[Generate] Temperature") + self.check_decimal_float(self.presence_penalty, "[Generate] Presence penalty") + self.check_decimal_float(self.frequency_penalty, "[Generate] Frequency penalty") + self.check_nonnegative_number(self.max_tokens, "[Generate] Max tokens") + self.check_decimal_float(self.top_p, "[Generate] Top P") + self.check_empty(self.llm_id, "[Generate] LLM") + # self.check_defined_type(self.parameters, "Parameters", ["list"]) def gen_conf(self): - return { - "max_tokens": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "presence_penalty": self.presence_penalty, - "frequency_penalty": self.frequency_penalty, - } + conf = {} + if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens + if self.temperature > 0: conf["temperature"] = self.temperature + if self.top_p > 0: conf["top_p"] = self.top_p + if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty + if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty + return conf class Generate(ComponentBase): @@ -69,12 +69,15 @@ class Generate(ComponentBase): retrieval_res = self.get_input() input = "\n- ".join(retrieval_res["content"]) - + for para in self._param.parameters: + cpn = self._canvas.get_component(para["component_id"])["obj"] + _, out = cpn.output(allow_partial=False) + kwargs[para["key"]] = "\n - ".join(out["content"]) kwargs["input"] = input for n, v in kwargs.items(): - #prompt = re.sub(r"\{%s\}"%n, re.escape(str(v)), prompt) - prompt = re.sub(r"\{%s\}"%n, str(v), prompt) + # prompt = re.sub(r"\{%s\}"%n, re.escape(str(v)), prompt) + prompt = re.sub(r"\{%s\}" % n, str(v), prompt) if kwargs.get("stream"): return partial(self.stream_output, chat_mdl, prompt, retrieval_res) @@ -82,23 +85,25 @@ class Generate(ComponentBase): if "empty_response" in retrieval_res.columns: return Generate.be_output(input) - ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()) + ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), + self._param.gen_conf()) if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: ans, idx = retrievaler.insert_citations(ans, - [ck["content_ltks"] - for _, ck in retrieval_res.iterrows()], - [ck["vector"] - for _,ck in retrieval_res.iterrows()], - LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()), - tkweight=0.7, - vtweight=0.3) + [ck["content_ltks"] + for _, ck in retrieval_res.iterrows()], + [ck["vector"] + for _, ck in retrieval_res.iterrows()], + LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, + self._canvas.get_embedding_model()), + tkweight=0.7, + vtweight=0.3) del retrieval_res["vector"] retrieval_res = retrieval_res.to_dict("records") df = [] for i in idx: df.append(retrieval_res[int(i)]) - r = re.search(r"^((.|[\r\n])*? ##%s\$\$)"%str(i), ans) + r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans) assert r, f"{i} => {ans}" df[-1]["content"] = r.group(1) ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans) @@ -116,20 +121,22 @@ class Generate(ComponentBase): return answer = "" - for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()): + for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size), + self._param.gen_conf()): res = {"content": ans, "reference": []} answer = ans yield res if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: answer, idx = retrievaler.insert_citations(answer, - [ck["content_ltks"] - for _, ck in retrieval_res.iterrows()], - [ck["vector"] - for _, ck in retrieval_res.iterrows()], - LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()), - tkweight=0.7, - vtweight=0.3) + [ck["content_ltks"] + for _, ck in retrieval_res.iterrows()], + [ck["vector"] + for _, ck in retrieval_res.iterrows()], + LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, + self._canvas.get_embedding_model()), + tkweight=0.7, + vtweight=0.3) doc_ids = set([]) recall_docs = [] for i in idx: @@ -152,5 +159,3 @@ class Generate(ComponentBase): yield res self.set_output(res) - - diff --git a/graph/component/message.py b/graph/component/message.py index 19c026c92..adbf4c8ad 100644 --- a/graph/component/message.py +++ b/graph/component/message.py @@ -32,7 +32,7 @@ class MessageParam(ComponentParamBase): self.messages = [] def check(self): - self.check_empty(self.messages, "Message") + self.check_empty(self.messages, "[Message]") return True diff --git a/graph/component/relevant.py b/graph/component/relevant.py index 20c6eb222..ab2fa318f 100644 --- a/graph/component/relevant.py +++ b/graph/component/relevant.py @@ -33,6 +33,8 @@ class RelevantParam(GenerateParam): def check(self): super().check() + self.check_empty(self.yes, "[Relevant] 'Yes'") + self.check_empty(self.no, "[Relevant] 'No'") def get_prompt(self): self.prompt = """ diff --git a/graph/component/retrieval.py b/graph/component/retrieval.py index 7daee8073..a765556ce 100644 --- a/graph/component/retrieval.py +++ b/graph/component/retrieval.py @@ -40,10 +40,10 @@ class RetrievalParam(ComponentParamBase): self.empty_response = "" def check(self): - self.check_decimal_float(self.similarity_threshold, "Similarity threshold") - self.check_decimal_float(self.keywords_similarity_weight, "Keywords similarity weight") - self.check_positive_number(self.top_n, "Top N") - self.check_empty(self.kb_ids, "Knowledge bases") + self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") + self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keywords similarity weight") + self.check_positive_number(self.top_n, "[Retrieval] Top N") + self.check_empty(self.kb_ids, "[Retrieval] Knowledge bases") class Retrieval(ComponentBase, ABC): diff --git a/graph/component/switch.py b/graph/component/switch.py index 790d16bb4..a431143d3 100644 --- a/graph/component/switch.py +++ b/graph/component/switch.py @@ -44,8 +44,10 @@ class SwitchParam(ComponentParamBase): self.default = "" def check(self): - self.check_empty(self.conditions, "Switch conditions") - self.check_empty(self.default, "Default path") + self.check_empty(self.conditions, "[Switch] conditions") + self.check_empty(self.default, "[Switch] Default path") + for cond in self.conditions: + if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!") def operators(self, field, op, value): if op == "gt":