add input variables to begin component (#3498)

### What problem does this PR solve?

#3355 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2024-11-19 18:41:48 +08:00
committed by GitHub
parent 0cd5b64c3b
commit 361cff34fc
7 changed files with 128 additions and 78 deletions

View File

@ -63,9 +63,11 @@ class Generate(ComponentBase):
component_name = "Generate"
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
cpnts = set([para["component_id"].split("@")[0] for para in self._param.parameters \
if para.get("component_id") \
and para["component_id"].lower().find("answer") < 0 \
and para["component_id"].lower().find("begin") < 0])
return list(cpnts)
def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
@ -107,11 +109,12 @@ class Generate(ComponentBase):
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if para["component_id"].split("@")[0].lower().find("begin") > 0:
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
kwargs[para["key"]] = p["value"]
kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append(
{"component_id": para["component_id"], "content": kwargs[para["key"]]})
break
@ -119,7 +122,7 @@ class Generate(ComponentBase):
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
cpn = self._canvas.get_component(para["component_id"])["obj"]
cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer":
kwargs[para["key"]] = self._canvas.get_history(1)[0]["content"]
continue
@ -129,14 +132,12 @@ class Generate(ComponentBase):
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else:
retrieval_res = pd.DataFrame([])
else: retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@ -158,6 +159,7 @@ class Generate(ComponentBase):
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
@ -178,6 +180,7 @@ class Generate(ComponentBase):
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
answer = ""