diff --git a/agent/component/base.py b/agent/component/base.py index 4e9122dae..4138ba9b4 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -463,6 +463,8 @@ class ComponentBase(ABC): if len(self._canvas.path) > 1: reversed_cpnts.extend(self._canvas.path[-2]) reversed_cpnts.extend(self._canvas.path[-1]) + up_cpns = self.get_upstream() + reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns] if self._param.query: self._param.inputs = [] @@ -505,7 +507,7 @@ class ComponentBase(ABC): upstream_outs = [] - for u in reversed_cpnts[::-1]: + for u in reversed_up_cpnts[::-1]: if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": @@ -565,8 +567,10 @@ class ComponentBase(ABC): if len(self._canvas.path) > 1: reversed_cpnts.extend(self._canvas.path[-2]) reversed_cpnts.extend(self._canvas.path[-1]) + up_cpns = self.get_upstream() + reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns] - for u in reversed_cpnts[::-1]: + for u in reversed_up_cpnts[::-1]: if self.get_component_name(u) in ["switch", "answer"]: continue return self._canvas.get_component(u)["obj"].output()[1] @@ -584,3 +588,7 @@ class ComponentBase(ABC): def get_parent(self): pid = self._canvas.get_component(self._id)["parent_id"] return self._canvas.get_component(pid)["obj"] + + def get_upstream(self): + cpn_nms = self._canvas.get_component(self._id)['upstream'] + return cpn_nms