mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add input variables to begin component (#3498)
### What problem does this PR solve? #3355 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,17 +13,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from abc import ABC
|
||||
import builtins
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from agent import settings
|
||||
|
||||
from agent.settings import flow_logger, DEBUG
|
||||
|
||||
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
||||
_DEPRECATED_PARAMS = "_deprecated_params"
|
||||
@ -82,7 +82,6 @@ class ComponentParamBase(ABC):
|
||||
return {name: True for name in self.get_feeded_deprecated_params()}
|
||||
|
||||
def __str__(self):
|
||||
|
||||
return json.dumps(self.as_dict(), ensure_ascii=False)
|
||||
|
||||
def as_dict(self):
|
||||
@ -398,8 +397,11 @@ class ComponentBase(ABC):
|
||||
self._param.check()
|
||||
|
||||
def get_dependent_components(self):
|
||||
cpnts = [para["component_id"] for para in self._param.query if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
|
||||
return cpnts
|
||||
cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
|
||||
if para.get("component_id") \
|
||||
and para["component_id"].lower().find("answer") < 0 \
|
||||
and para["component_id"].lower().find("begin") < 0])
|
||||
return list(cpnts)
|
||||
|
||||
def run(self, history, **kwargs):
|
||||
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
|
||||
@ -416,7 +418,7 @@ class ComponentBase(ABC):
|
||||
def _run(self, history, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def output(self, allow_partial=True) -> tuple[str, pd.DataFrame | partial]:
|
||||
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
|
||||
o = getattr(self._param, self._param.output_var_name)
|
||||
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
|
||||
if not isinstance(o, list): o = [o]
|
||||
@ -436,12 +438,19 @@ class ComponentBase(ABC):
|
||||
|
||||
def reset(self):
|
||||
setattr(self._param, self._param.output_var_name, None)
|
||||
self._param.inputs = []
|
||||
|
||||
def set_output(self, v: pd.DataFrame):
|
||||
setattr(self._param, self._param.output_var_name, v)
|
||||
|
||||
def get_input(self):
|
||||
reversed_cpnts = []
|
||||
if len(self._canvas.path) > 1:
|
||||
reversed_cpnts.extend(self._canvas.path[-2])
|
||||
reversed_cpnts.extend(self._canvas.path[-1])
|
||||
|
||||
if self._param.query:
|
||||
self._param.inputs = []
|
||||
outs = []
|
||||
for q in self._param.query:
|
||||
if q["component_id"]:
|
||||
@ -449,9 +458,9 @@ class ComponentBase(ABC):
|
||||
cpn_id, key = q["component_id"].split("@")
|
||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||
if p["key"] == key:
|
||||
outs.append(pd.DataFrame([{"content": p["value"]}]))
|
||||
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
|
||||
self._param.inputs.append({"component_id": q["component_id"],
|
||||
"content": p["value"]})
|
||||
"content": p.get("value", "")})
|
||||
break
|
||||
else:
|
||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||
@ -470,12 +479,8 @@ class ComponentBase(ABC):
|
||||
return df
|
||||
|
||||
upstream_outs = []
|
||||
reversed_cpnts = []
|
||||
if len(self._canvas.path) > 1:
|
||||
reversed_cpnts.extend(self._canvas.path[-2])
|
||||
reversed_cpnts.extend(self._canvas.path[-1])
|
||||
|
||||
logging.debug(f"{self.component_name} {reversed_cpnts[::-1]}")
|
||||
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
|
||||
for u in reversed_cpnts[::-1]:
|
||||
if self.get_component_name(u) in ["switch", "concentrator"]: continue
|
||||
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
||||
@ -484,7 +489,7 @@ class ComponentBase(ABC):
|
||||
o["component_id"] = u
|
||||
upstream_outs.append(o)
|
||||
continue
|
||||
if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
|
||||
#if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
|
||||
if self.component_name.lower().find("switch") < 0 \
|
||||
and self.get_component_name(u) in ["relevant", "categorize"]:
|
||||
continue
|
||||
@ -502,14 +507,14 @@ class ComponentBase(ABC):
|
||||
upstream_outs.append(o)
|
||||
break
|
||||
|
||||
assert upstream_outs, "Can't inference the where the component input is."
|
||||
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
|
||||
|
||||
df = pd.concat(upstream_outs, ignore_index=True)
|
||||
if "content" in df:
|
||||
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
||||
|
||||
self._param.inputs = []
|
||||
for _,r in df.iterrows():
|
||||
for _, r in df.iterrows():
|
||||
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
|
||||
|
||||
return df
|
||||
|
||||
Reference in New Issue
Block a user