mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve? #9082 #6365 <u> **WARNING: it's not compatible with the older version of `Agent` module, which means that `Agent` from older versions can not work anymore.**</u> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,17 +13,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
import pandas as pd
|
||||
|
||||
import trio
|
||||
from agent import settings
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
||||
_DEPRECATED_PARAMS = "_deprecated_params"
|
||||
@ -33,12 +35,17 @@ _IS_RAW_CONF = "_is_raw_conf"
|
||||
|
||||
class ComponentParamBase(ABC):
|
||||
def __init__(self):
|
||||
self.output_var_name = "output"
|
||||
self.infor_var_name = "infor"
|
||||
self.message_history_window_size = 22
|
||||
self.query = []
|
||||
self.inputs = []
|
||||
self.debug_inputs = []
|
||||
self.inputs = {}
|
||||
self.outputs = {}
|
||||
self.description = ""
|
||||
self.max_retries = 0
|
||||
self.delay_after_error = 2.0
|
||||
self.exception_method = None
|
||||
self.exception_default_value = None
|
||||
self.exception_comment = None
|
||||
self.exception_goto = None
|
||||
self.debug_inputs = {}
|
||||
|
||||
def set_name(self, name: str):
|
||||
self._name = name
|
||||
@ -110,11 +117,15 @@ class ComponentParamBase(ABC):
|
||||
update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
|
||||
if update_from_raw_conf:
|
||||
deprecated_params_set = self._get_or_init_deprecated_params_set()
|
||||
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set()
|
||||
feeded_deprecated_params_set = (
|
||||
self._get_or_init_feeded_deprecated_params_set()
|
||||
)
|
||||
user_feeded_params_set = self._get_or_init_user_feeded_params_set()
|
||||
setattr(self, _IS_RAW_CONF, False)
|
||||
else:
|
||||
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf)
|
||||
feeded_deprecated_params_set = (
|
||||
self._get_or_init_feeded_deprecated_params_set(conf)
|
||||
)
|
||||
user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
|
||||
|
||||
def _recursive_update_param(param, config, depth, prefix):
|
||||
@ -150,11 +161,15 @@ class ComponentParamBase(ABC):
|
||||
|
||||
else:
|
||||
# recursive set obj attr
|
||||
sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.")
|
||||
sub_params = _recursive_update_param(
|
||||
attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
|
||||
)
|
||||
setattr(param, config_key, sub_params)
|
||||
|
||||
if not allow_redundant and redundant_attrs:
|
||||
raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`")
|
||||
raise ValueError(
|
||||
f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
|
||||
)
|
||||
|
||||
return param
|
||||
|
||||
@ -185,7 +200,9 @@ class ComponentParamBase(ABC):
|
||||
param_validation_path_prefix = home_dir + "/param_validation/"
|
||||
|
||||
param_name = type(self).__name__
|
||||
param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"])
|
||||
param_validation_path = "/".join(
|
||||
[param_validation_path_prefix, param_name + ".json"]
|
||||
)
|
||||
|
||||
validation_json = None
|
||||
|
||||
@ -218,7 +235,11 @@ class ComponentParamBase(ABC):
|
||||
break
|
||||
|
||||
if not value_legal:
|
||||
raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value))
|
||||
raise ValueError(
|
||||
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
|
||||
variable, value
|
||||
)
|
||||
)
|
||||
|
||||
elif variable in validation_json:
|
||||
self._validate_param(attr, validation_json)
|
||||
@ -226,63 +247,94 @@ class ComponentParamBase(ABC):
|
||||
@staticmethod
|
||||
def check_string(param, descr):
|
||||
if type(param).__name__ not in ["str"]:
|
||||
raise ValueError(descr + " {} not supported, should be string type".format(param))
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be string type".format(param)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_empty(param, descr):
|
||||
if not param:
|
||||
raise ValueError(descr + " does not support empty value.")
|
||||
raise ValueError(
|
||||
descr + " does not support empty value."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_positive_integer(param, descr):
|
||||
if type(param).__name__ not in ["int", "long"] or param <= 0:
|
||||
raise ValueError(descr + " {} not supported, should be positive integer".format(param))
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be positive integer".format(param)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_positive_number(param, descr):
|
||||
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
|
||||
raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be positive numeric".format(param)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_nonnegative_number(param, descr):
|
||||
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
|
||||
raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param))
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be non-negative numeric".format(param)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_decimal_float(param, descr):
|
||||
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
|
||||
raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param))
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be a float number in range [0, 1]".format(
|
||||
param
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_boolean(param, descr):
|
||||
if type(param).__name__ != "bool":
|
||||
raise ValueError(descr + " {} not supported, should be bool type".format(param))
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be bool type".format(param)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_open_unit_interval(param, descr):
|
||||
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
|
||||
raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
|
||||
raise ValueError(
|
||||
descr + " should be a numeric number between 0 and 1 exclusively"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_valid_value(param, descr, valid_values):
|
||||
if param not in valid_values:
|
||||
raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values))
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} is not supported, it should be in {}".format(param, valid_values)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_defined_type(param, descr, types):
|
||||
if type(param).__name__ not in types:
|
||||
raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be one of {}".format(param, types)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_and_change_lower(param, valid_list, descr=""):
|
||||
if type(param).__name__ != "str":
|
||||
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be one of {}".format(param, valid_list)
|
||||
)
|
||||
|
||||
lower_param = param.lower()
|
||||
if lower_param in valid_list:
|
||||
return lower_param
|
||||
else:
|
||||
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be one of {}".format(param, valid_list)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _greater_equal_than(value, limit):
|
||||
@ -296,7 +348,11 @@ class ComponentParamBase(ABC):
|
||||
def _range(value, ranges):
|
||||
in_range = False
|
||||
for left_limit, right_limit in ranges:
|
||||
if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO:
|
||||
if (
|
||||
left_limit - settings.FLOAT_ZERO
|
||||
<= value
|
||||
<= right_limit + settings.FLOAT_ZERO
|
||||
):
|
||||
in_range = True
|
||||
break
|
||||
|
||||
@ -312,17 +368,24 @@ class ComponentParamBase(ABC):
|
||||
|
||||
def _warn_deprecated_param(self, param_name, descr):
|
||||
if self._deprecated_params_set.get(param_name):
|
||||
logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.")
|
||||
logging.warning(
|
||||
f"{descr} {param_name} is deprecated and ignored in this version."
|
||||
)
|
||||
|
||||
def _warn_to_deprecate_param(self, param_name, descr, new_param):
|
||||
if self._deprecated_params_set.get(param_name):
|
||||
logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.")
|
||||
logging.warning(
|
||||
f"{descr} {param_name} will be deprecated in future release; "
|
||||
f"please use {new_param} instead."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
@ -331,232 +394,144 @@ class ComponentBase(ABC):
|
||||
"params": {}
|
||||
}
|
||||
"""
|
||||
out = getattr(self._param, self._param.output_var_name)
|
||||
if isinstance(out, pd.DataFrame) and "chunks" in out:
|
||||
del out["chunks"]
|
||||
setattr(self._param, self._param.output_var_name, out)
|
||||
|
||||
return """{{
|
||||
"component_name": "{}",
|
||||
"params": {},
|
||||
"output": {},
|
||||
"inputs": {}
|
||||
}}""".format(
|
||||
self.component_name,
|
||||
self._param,
|
||||
json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
|
||||
json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False),
|
||||
"params": {}
|
||||
}}""".format(self.component_name,
|
||||
self._param
|
||||
)
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Canvas # Local import to avoid cyclic dependency
|
||||
|
||||
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
|
||||
self._canvas = canvas
|
||||
self._id = id
|
||||
self._param = param
|
||||
self._param.check()
|
||||
|
||||
def get_dependent_components(self):
|
||||
cpnts = set(
|
||||
[
|
||||
para["component_id"].split("@")[0]
|
||||
for para in self._param.query
|
||||
if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0
|
||||
]
|
||||
)
|
||||
return list(cpnts)
|
||||
|
||||
def run(self, history, **kwargs):
|
||||
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False)))
|
||||
self._param.debug_inputs = []
|
||||
def invoke(self, **kwargs) -> dict[str, Any]:
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
res = self._run(history, **kwargs)
|
||||
self.set_output(res)
|
||||
self._invoke(**kwargs)
|
||||
except Exception as e:
|
||||
self.set_output(pd.DataFrame([{"content": str(e)}]))
|
||||
raise e
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
self._param.debug_inputs = {}
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
return res
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
|
||||
o = getattr(self._param, self._param.output_var_name)
|
||||
if not isinstance(o, partial):
|
||||
if not isinstance(o, pd.DataFrame):
|
||||
if isinstance(o, list):
|
||||
return self._param.output_var_name, pd.DataFrame(o).dropna()
|
||||
if o is None:
|
||||
return self._param.output_var_name, pd.DataFrame()
|
||||
return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
|
||||
return self._param.output_var_name, o
|
||||
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
|
||||
if var_nm:
|
||||
return self._param.outputs.get(var_nm, {}).get("value")
|
||||
return {k: o.get("value") for k,o in self._param.outputs.items()}
|
||||
|
||||
if allow_partial or not isinstance(o, partial):
|
||||
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
|
||||
return pd.DataFrame(o if isinstance(o, list) else [o]).dropna()
|
||||
return self._param.output_var_name, o
|
||||
def set_output(self, key: str, value: Any):
|
||||
if key not in self._param.outputs:
|
||||
self._param.outputs[key] = {"value": None, "type": str(type(value))}
|
||||
self._param.outputs[key]["value"] = value
|
||||
|
||||
outs = None
|
||||
for oo in o():
|
||||
if not isinstance(oo, pd.DataFrame):
|
||||
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]).dropna()
|
||||
else:
|
||||
outs = oo.dropna()
|
||||
return self._param.output_var_name, outs
|
||||
def error(self):
|
||||
return self._param.outputs.get("_ERROR", {}).get("value")
|
||||
|
||||
def reset(self):
|
||||
setattr(self._param, self._param.output_var_name, None)
|
||||
self._param.inputs = []
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
def set_output(self, v):
|
||||
setattr(self._param, self._param.output_var_name, v)
|
||||
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
|
||||
if key:
|
||||
return self._param.inputs.get(key, {}).get("value")
|
||||
|
||||
def set_infor(self, v):
|
||||
setattr(self._param, self._param.infor_var_name, v)
|
||||
|
||||
def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
|
||||
outs = []
|
||||
for q in sources:
|
||||
if q.get("component_id"):
|
||||
if "@" in q["component_id"] and q["component_id"].split("@")[0].lower().find("begin") >= 0:
|
||||
cpn_id, key = q["component_id"].split("@")
|
||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||
if p["key"] == key:
|
||||
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
|
||||
break
|
||||
else:
|
||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||
continue
|
||||
|
||||
if q["component_id"].lower().find("answer") == 0:
|
||||
txt = []
|
||||
for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]:
|
||||
txt.append(f"{r.upper()}:{c}")
|
||||
txt = "\n".join(txt)
|
||||
outs.append(pd.DataFrame([{"content": txt}]))
|
||||
continue
|
||||
|
||||
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
|
||||
elif q.get("value"):
|
||||
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
||||
return outs
|
||||
def get_input(self):
|
||||
if self._param.debug_inputs:
|
||||
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
|
||||
|
||||
reversed_cpnts = []
|
||||
if len(self._canvas.path) > 1:
|
||||
reversed_cpnts.extend(self._canvas.path[-2])
|
||||
reversed_cpnts.extend(self._canvas.path[-1])
|
||||
up_cpns = self.get_upstream()
|
||||
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
|
||||
|
||||
if self._param.query:
|
||||
self._param.inputs = []
|
||||
outs = self._fetch_outputs_from(self._param.query)
|
||||
|
||||
for out in outs:
|
||||
records = out.to_dict("records")
|
||||
content: str
|
||||
|
||||
if len(records) > 1:
|
||||
content = "\n".join([str(d["content"]) for d in records])
|
||||
else:
|
||||
content = records[0]["content"]
|
||||
|
||||
self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content})
|
||||
|
||||
if outs:
|
||||
df = pd.concat(outs, ignore_index=True)
|
||||
if "content" in df:
|
||||
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
|
||||
return df
|
||||
|
||||
upstream_outs = []
|
||||
|
||||
for u in reversed_up_cpnts[::-1]:
|
||||
if self.get_component_name(u) in ["switch", "concentrator"]:
|
||||
res = {}
|
||||
for var, o in self.get_input_elements().items():
|
||||
v = self.get_param(var)
|
||||
if v is None:
|
||||
continue
|
||||
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
||||
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
||||
if o is not None:
|
||||
o["component_id"] = u
|
||||
upstream_outs.append(o)
|
||||
continue
|
||||
# if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
|
||||
if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]:
|
||||
continue
|
||||
if u.lower().find("answer") >= 0:
|
||||
for r, c in self._canvas.history[::-1]:
|
||||
if r == "user":
|
||||
upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
|
||||
break
|
||||
break
|
||||
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
|
||||
continue
|
||||
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
||||
if o is not None:
|
||||
o["component_id"] = u
|
||||
upstream_outs.append(o)
|
||||
break
|
||||
|
||||
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
|
||||
|
||||
df = pd.concat(upstream_outs, ignore_index=True)
|
||||
if "content" in df:
|
||||
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
|
||||
|
||||
self._param.inputs = []
|
||||
for _, r in df.iterrows():
|
||||
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
|
||||
|
||||
return df
|
||||
|
||||
def get_input_elements(self):
|
||||
assert self._param.query, "Please verify the input parameters first."
|
||||
eles = []
|
||||
for q in self._param.query:
|
||||
if q.get("component_id"):
|
||||
cpn_id = q["component_id"]
|
||||
if cpn_id.split("@")[0].lower().find("begin") >= 0:
|
||||
cpn_id, key = cpn_id.split("@")
|
||||
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
|
||||
continue
|
||||
|
||||
eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
|
||||
if isinstance(v, str) and self._canvas.is_reff(v):
|
||||
self.set_input_value(var, self._canvas.get_variable_value(v))
|
||||
else:
|
||||
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
|
||||
return eles
|
||||
self.set_input_value(var, v)
|
||||
res[var] = self.get_input_value(var)
|
||||
return res
|
||||
|
||||
def get_stream_input(self):
|
||||
reversed_cpnts = []
|
||||
if len(self._canvas.path) > 1:
|
||||
reversed_cpnts.extend(self._canvas.path[-2])
|
||||
reversed_cpnts.extend(self._canvas.path[-1])
|
||||
up_cpns = self.get_upstream()
|
||||
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
|
||||
def get_input_values(self) -> Union[Any, dict[str, Any]]:
|
||||
if self._param.debug_inputs:
|
||||
return self._param.debug_inputs
|
||||
|
||||
for u in reversed_up_cpnts[::-1]:
|
||||
if self.get_component_name(u) in ["switch", "answer"]:
|
||||
continue
|
||||
return self._canvas.get_component(u)["obj"].output()[1]
|
||||
return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()}
|
||||
|
||||
@staticmethod
|
||||
def be_output(v):
|
||||
return pd.DataFrame([{"content": v}])
|
||||
def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
|
||||
res = {}
|
||||
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE):
|
||||
exp = r.group(1)
|
||||
cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
|
||||
res[exp] = {
|
||||
"name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp,
|
||||
"value": self._canvas.get_variable_value(exp),
|
||||
"_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
|
||||
"_cpn_id": cpn_id
|
||||
}
|
||||
return res
|
||||
|
||||
def get_component_name(self, cpn_id):
|
||||
def get_input_elements(self) -> dict[str, Any]:
|
||||
return self._param.inputs
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return self._param.get_input_form()
|
||||
|
||||
def set_input_value(self, key: str, value: Any) -> None:
|
||||
if key not in self._param.inputs:
|
||||
self._param.inputs[key] = {"value": None}
|
||||
self._param.inputs[key]["value"] = value
|
||||
|
||||
def get_input_value(self, key: str) -> Any:
|
||||
if key not in self._param.inputs:
|
||||
return None
|
||||
return self._param.inputs[key].get("value")
|
||||
|
||||
def get_component_name(self, cpn_id) -> str:
|
||||
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
|
||||
|
||||
def debug(self, **kwargs):
|
||||
return self._run([], **kwargs)
|
||||
def get_param(self, name):
|
||||
if hasattr(self._param, name):
|
||||
return getattr(self._param, name)
|
||||
|
||||
def get_parent(self):
|
||||
pid = self._canvas.get_component(self._id)["parent_id"]
|
||||
def debug(self, **kwargs):
|
||||
return self._invoke(**kwargs)
|
||||
|
||||
def get_parent(self) -> Union[object, None]:
|
||||
pid = self._canvas.get_component(self._id).get("parent_id")
|
||||
if not pid:
|
||||
return
|
||||
return self._canvas.get_component(pid)["obj"]
|
||||
|
||||
def get_upstream(self):
|
||||
cpn_nms = self._canvas.get_component(self._id)["upstream"]
|
||||
def get_upstream(self) -> List[str]:
|
||||
cpn_nms = self._canvas.get_component(self._id)['upstream']
|
||||
return cpn_nms
|
||||
|
||||
@staticmethod
|
||||
def string_format(content: str, kv: dict[str, str]) -> str:
|
||||
for n, v in kv.items():
|
||||
content = re.sub(
|
||||
r"\{%s\}" % re.escape(n), re.escape(v), content
|
||||
)
|
||||
return content
|
||||
|
||||
def exception_handler(self):
|
||||
if not self._param.exception_method:
|
||||
return
|
||||
return {
|
||||
"goto": self._param.exception_goto,
|
||||
"comment": self._param.exception_comment,
|
||||
"default_value": self._param.exception_default_value
|
||||
}
|
||||
|
||||
def get_exception_default_value(self):
|
||||
return self._param.exception_default_value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user