add input variables to begin component (#3498)

### What problem does this PR solve?

#3355 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2024-11-19 18:41:48 +08:00
committed by GitHub
parent 0cd5b64c3b
commit 361cff34fc
7 changed files with 128 additions and 78 deletions

View File

@ -13,17 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import builtins
import json
import os
from functools import partial
from typing import Tuple, Union
import pandas as pd
from agent import settings
from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@ -82,7 +82,6 @@ class ComponentParamBase(ABC):
return {name: True for name in self.get_feeded_deprecated_params()}
def __str__(self):
return json.dumps(self.as_dict(), ensure_ascii=False)
def as_dict(self):
@ -398,8 +397,11 @@ class ComponentBase(ABC):
self._param.check()
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.query if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
@ -416,7 +418,7 @@ class ComponentBase(ABC):
def _run(self, history, **kwargs):
raise NotImplementedError()
def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]:
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
@ -436,12 +438,19 @@ class ComponentBase(ABC):
def reset(self):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v)
def get_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
if self._param.query:
self._param.inputs = []
outs = []
for q in self._param.query:
if q["component_id"]:
@ -449,9 +458,9 @@ class ComponentBase(ABC):
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p["value"]}]))
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"],
"content": p["value"]})
"content": p.get("value", "")})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
@ -470,12 +479,8 @@ class ComponentBase(ABC):
return df
upstream_outs = []
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
logging.debug(f"{self.component_name} {reversed_cpnts[::-1]}")
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
@ -484,7 +489,7 @@ class ComponentBase(ABC):
o["component_id"] = u
upstream_outs.append(o)
continue
if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
#if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 \
and self.get_component_name(u) in ["relevant", "categorize"]:
continue
@ -502,14 +507,14 @@ class ComponentBase(ABC):
upstream_outs.append(o)
break
assert upstream_outs, "Can't inference the where the component input is."
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
df = pd.concat(upstream_outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
self._param.inputs = []
for _,r in df.iterrows():
for _, r in df.iterrows():
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
return df

View File

@ -63,9 +63,11 @@ class Generate(ComponentBase):
component_name = "Generate"
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \
if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
@ -107,11 +109,12 @@ class Generate(ComponentBase):
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if para["component_id"].split("@")[0].lower().find("begin") > 0:
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
kwargs[para["key"]] = p["value"]
kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append(
{"component_id": para["component_id"], "content": kwargs[para["key"]]})
break
@ -119,7 +122,7 @@ class Generate(ComponentBase):
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
cpn = self._canvas.get_component(para["component_id"])["obj"]
cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer":
kwargs[para["key"]] = self._canvas.get_history(1)[0]["content"]
continue
@ -129,14 +132,12 @@ class Generate(ComponentBase):
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else:
retrieval_res = pd.DataFrame([])
else: retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@ -158,6 +159,7 @@ class Generate(ComponentBase):
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
@ -178,6 +180,7 @@ class Generate(ComponentBase):
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
answer = ""

View File

@ -47,13 +47,35 @@ class SwitchParam(ComponentParamBase):
class Switch(ComponentBase, ABC):
component_name = "Switch"
def get_dependent_components(self):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
res.append(cid)
return list(set(res))
def _run(self, history, **kwargs):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
out = self._canvas.get_component(item["cpn_id"])["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item["value"]))
if not item["cpn_id"]:continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
for p in self._canvas.get_component(cid)["obj"]._param.query:
if p["key"] == key:
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
break
else:
out = self._canvas.get_component(cid)["obj"].output()[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item.get("value", "")))
if cond["logical_operator"] != "and" and any(res):
return Switch.be_output(cond["to"])