From 2a647162a8c335c0a54b378321236764a89bd2e8 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Tue, 16 Jul 2024 09:28:13 +0800
Subject: [PATCH] fix bugs about multi input for generate (#1525)

### What problem does this PR solve?

A Generate component that takes input from several upstream components could run before all of the components referenced by its prompt parameters had produced output, and input collection could pick up output from components that are not actually upstream of the node. The changes below defer a Generate component until every component it depends on has run, and limit input gathering to a component's declared upstream nodes.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 graph/canvas.py             | 6 +++++-
 graph/component/base.py     | 1 +
 graph/component/generate.py | 4 ++++
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/graph/canvas.py b/graph/canvas.py
index 530bb6fbe..9b8f4c62c 100644
--- a/graph/canvas.py
+++ b/graph/canvas.py
@@ -193,9 +193,13 @@ class Canvas(ABC):
                     self.answer.append(c)
                 else:
                     if DEBUG: print("RUN: ", c)
+                    if cpn.component_name == "Generate":
+                        cpids = cpn.get_dependent_components()
+                        if any([c not in self.path[-1] for c in cpids]):
+                            continue
                     ans = cpn.run(self.history, **kwargs)
                     self.path[-1].append(c)
-                ran += 1
+                    ran += 1
 
         prepare2run(self.components[self.path[-2][-1]]["downstream"])
         while 0 <= ran < len(self.path[-1]):
diff --git a/graph/component/base.py b/graph/component/base.py
index 7c0b42e1e..82e2f18c7 100644
--- a/graph/component/base.py
+++ b/graph/component/base.py
@@ -445,6 +445,7 @@ class ComponentBase(ABC):
         if DEBUG: print(self.component_name, reversed_cpnts[::-1])
         for u in reversed_cpnts[::-1]:
             if self.get_component_name(u) in ["switch"]: continue
+            if u not in self._canvas.get_component(self._id)["upstream"]: continue
             if self.component_name.lower().find("switch") < 0 \
                     and self.get_component_name(u) in ["relevant", "categorize"]:
                 continue
diff --git a/graph/component/generate.py b/graph/component/generate.py
index 8b6525112..693bdba7a 100644
--- a/graph/component/generate.py
+++ b/graph/component/generate.py
@@ -63,6 +63,10 @@ class GenerateParam(ComponentParamBase):
 class Generate(ComponentBase):
     component_name = "Generate"
 
+    def get_dependent_components(self):
+        cpnts = [para["component_id"] for para in self._param.parameters]
+        return cpnts
+
     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
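
To make the scheduling change easier to follow outside the full RAGFlow codebase, here is a minimal, self-contained sketch of the gating behaviour the patch introduces: a `Generate` node reports the component ids referenced by its prompt parameters via `get_dependent_components()`, and the scheduler defers the node until every one of those ids already appears in the executed path. The `ToyComponent`/`ToyGenerate` classes and the standalone `prepare2run` function below are illustrative assumptions, not RAGFlow's actual classes or signatures.

```python
class ToyComponent:
    component_name = "Component"

    def __init__(self, cid):
        self.cid = cid

    def run(self):
        return f"output of {self.cid}"


class ToyGenerate(ToyComponent):
    component_name = "Generate"

    def __init__(self, cid, parameters):
        super().__init__(cid)
        # Mirrors GenerateParam.parameters: each entry names an upstream
        # component whose output is substituted into the prompt.
        self.parameters = parameters

    def get_dependent_components(self):
        return [p["component_id"] for p in self.parameters]

    def run(self):
        return f"{self.cid}: answer built from {self.get_dependent_components()}"


def prepare2run(candidates, finished, outputs):
    """Run each candidate, deferring Generate nodes whose inputs are not all ready."""
    for cpn in candidates:
        if cpn.cid in finished:
            continue
        if cpn.component_name == "Generate":
            deps = cpn.get_dependent_components()
            # Same idea as the patched canvas.py: skip the node until every
            # referenced component id already appears in the executed path.
            if any(d not in finished for d in deps):
                print(f"defer {cpn.cid}: waiting for {sorted(set(deps) - set(finished))}")
                continue
        outputs[cpn.cid] = cpn.run()
        finished.append(cpn.cid)


if __name__ == "__main__":
    retrieval_a = ToyComponent("retrieval:a")
    retrieval_b = ToyComponent("retrieval:b")
    generate = ToyGenerate(
        "generate:answer",
        parameters=[{"component_id": "retrieval:a"},
                    {"component_id": "retrieval:b"}],
    )

    finished, outputs = [], {}
    # First pass: only one retrieval has run, so the Generate node is deferred.
    prepare2run([retrieval_a, generate], finished, outputs)
    # Second pass: the second retrieval has run, so the Generate node fires.
    prepare2run([retrieval_b, generate], finished, outputs)
    print(outputs)
```

In the patched canvas the deferred node is simply offered to `prepare2run` again on a later pass over its upstream components' downstream lists; once the membership check against `self.path[-1]` passes, the Generate node runs a single time with all of its inputs available.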