Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u> **WARNING: it's not compatible with the older version of `Agent`
module, which means that `Agent` from older versions can not work
anymore.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-07-30 19:41:09 +08:00
committed by GitHub
parent 07e37560fc
commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions

View File

@ -13,17 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import time
from abc import ABC
import builtins
import json
import logging
import os
from abc import ABC
from functools import partial
from typing import Any, Tuple, Union
import logging
from typing import Any, List, Union
import pandas as pd
import trio
from agent import settings
from api.utils.api_utils import timeout
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@ -33,12 +35,17 @@ _IS_RAW_CONF = "_is_raw_conf"
class ComponentParamBase(ABC):
def __init__(self):
self.output_var_name = "output"
self.infor_var_name = "infor"
self.message_history_window_size = 22
self.query = []
self.inputs = []
self.debug_inputs = []
self.inputs = {}
self.outputs = {}
self.description = ""
self.max_retries = 0
self.delay_after_error = 2.0
self.exception_method = None
self.exception_default_value = None
self.exception_comment = None
self.exception_goto = None
self.debug_inputs = {}
def set_name(self, name: str):
self._name = name
@ -110,11 +117,15 @@ class ComponentParamBase(ABC):
update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
if update_from_raw_conf:
deprecated_params_set = self._get_or_init_deprecated_params_set()
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set()
feeded_deprecated_params_set = (
self._get_or_init_feeded_deprecated_params_set()
)
user_feeded_params_set = self._get_or_init_user_feeded_params_set()
setattr(self, _IS_RAW_CONF, False)
else:
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf)
feeded_deprecated_params_set = (
self._get_or_init_feeded_deprecated_params_set(conf)
)
user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
def _recursive_update_param(param, config, depth, prefix):
@ -150,11 +161,15 @@ class ComponentParamBase(ABC):
else:
# recursive set obj attr
sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.")
sub_params = _recursive_update_param(
attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
)
setattr(param, config_key, sub_params)
if not allow_redundant and redundant_attrs:
raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`")
raise ValueError(
f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
)
return param
@ -185,7 +200,9 @@ class ComponentParamBase(ABC):
param_validation_path_prefix = home_dir + "/param_validation/"
param_name = type(self).__name__
param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"])
param_validation_path = "/".join(
[param_validation_path_prefix, param_name + ".json"]
)
validation_json = None
@ -218,7 +235,11 @@ class ComponentParamBase(ABC):
break
if not value_legal:
raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value))
raise ValueError(
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
variable, value
)
)
elif variable in validation_json:
self._validate_param(attr, validation_json)
@ -226,63 +247,94 @@ class ComponentParamBase(ABC):
@staticmethod
def check_string(param, descr):
if type(param).__name__ not in ["str"]:
raise ValueError(descr + " {} not supported, should be string type".format(param))
raise ValueError(
descr + " {} not supported, should be string type".format(param)
)
@staticmethod
def check_empty(param, descr):
if not param:
raise ValueError(descr + " does not support empty value.")
raise ValueError(
descr + " does not support empty value."
)
@staticmethod
def check_positive_integer(param, descr):
if type(param).__name__ not in ["int", "long"] or param <= 0:
raise ValueError(descr + " {} not supported, should be positive integer".format(param))
raise ValueError(
descr + " {} not supported, should be positive integer".format(param)
)
@staticmethod
def check_positive_number(param, descr):
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
raise ValueError(
descr + " {} not supported, should be positive numeric".format(param)
)
@staticmethod
def check_nonnegative_number(param, descr):
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param))
raise ValueError(
descr
+ " {} not supported, should be non-negative numeric".format(param)
)
@staticmethod
def check_decimal_float(param, descr):
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param))
raise ValueError(
descr
+ " {} not supported, should be a float number in range [0, 1]".format(
param
)
)
@staticmethod
def check_boolean(param, descr):
if type(param).__name__ != "bool":
raise ValueError(descr + " {} not supported, should be bool type".format(param))
raise ValueError(
descr + " {} not supported, should be bool type".format(param)
)
@staticmethod
def check_open_unit_interval(param, descr):
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
raise ValueError(
descr + " should be a numeric number between 0 and 1 exclusively"
)
@staticmethod
def check_valid_value(param, descr, valid_values):
if param not in valid_values:
raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values))
raise ValueError(
descr
+ " {} is not supported, it should be in {}".format(param, valid_values)
)
@staticmethod
def check_defined_type(param, descr, types):
if type(param).__name__ not in types:
raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
raise ValueError(
descr + " {} not supported, should be one of {}".format(param, types)
)
@staticmethod
def check_and_change_lower(param, valid_list, descr=""):
if type(param).__name__ != "str":
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
raise ValueError(
descr
+ " {} not supported, should be one of {}".format(param, valid_list)
)
lower_param = param.lower()
if lower_param in valid_list:
return lower_param
else:
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
raise ValueError(
descr
+ " {} not supported, should be one of {}".format(param, valid_list)
)
@staticmethod
def _greater_equal_than(value, limit):
@ -296,7 +348,11 @@ class ComponentParamBase(ABC):
def _range(value, ranges):
in_range = False
for left_limit, right_limit in ranges:
if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO:
if (
left_limit - settings.FLOAT_ZERO
<= value
<= right_limit + settings.FLOAT_ZERO
):
in_range = True
break
@ -312,17 +368,24 @@ class ComponentParamBase(ABC):
def _warn_deprecated_param(self, param_name, descr):
if self._deprecated_params_set.get(param_name):
logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.")
logging.warning(
f"{descr} {param_name} is deprecated and ignored in this version."
)
def _warn_to_deprecate_param(self, param_name, descr, new_param):
if self._deprecated_params_set.get(param_name):
logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.")
logging.warning(
f"{descr} {param_name} will be deprecated in future release; "
f"please use {new_param} instead."
)
return True
return False
class ComponentBase(ABC):
component_name: str
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*"
def __str__(self):
"""
@ -331,232 +394,144 @@ class ComponentBase(ABC):
"params": {}
}
"""
out = getattr(self._param, self._param.output_var_name)
if isinstance(out, pd.DataFrame) and "chunks" in out:
del out["chunks"]
setattr(self._param, self._param.output_var_name, out)
return """{{
"component_name": "{}",
"params": {},
"output": {},
"inputs": {}
}}""".format(
self.component_name,
self._param,
json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False),
"params": {}
}}""".format(self.component_name,
self._param
)
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
self._param = param
self._param.check()
def get_dependent_components(self):
cpnts = set(
[
para["component_id"].split("@")[0]
for para in self._param.query
if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0
]
)
return list(cpnts)
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False)))
self._param.debug_inputs = []
def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
try:
res = self._run(history, **kwargs)
self.set_output(res)
self._invoke(**kwargs)
except Exception as e:
self.set_output(pd.DataFrame([{"content": str(e)}]))
raise e
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
return res
def _run(self, history, **kwargs):
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
raise NotImplementedError()
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial):
if not isinstance(o, pd.DataFrame):
if isinstance(o, list):
return self._param.output_var_name, pd.DataFrame(o).dropna()
if o is None:
return self._param.output_var_name, pd.DataFrame()
return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
return self._param.output_var_name, o
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
if var_nm:
return self._param.outputs.get(var_nm, {}).get("value")
return {k: o.get("value") for k,o in self._param.outputs.items()}
if allow_partial or not isinstance(o, partial):
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
return pd.DataFrame(o if isinstance(o, list) else [o]).dropna()
return self._param.output_var_name, o
def set_output(self, key: str, value: Any):
if key not in self._param.outputs:
self._param.outputs[key] = {"value": None, "type": str(type(value))}
self._param.outputs[key]["value"] = value
outs = None
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]).dropna()
else:
outs = oo.dropna()
return self._param.output_var_name, outs
def error(self):
return self._param.outputs.get("_ERROR", {}).get("value")
def reset(self):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
if key:
return self._param.inputs.get(key, {}).get("value")
def set_infor(self, v):
setattr(self._param, self._param.infor_var_name, v)
def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
outs = []
for q in sources:
if q.get("component_id"):
if "@" in q["component_id"] and q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
outs.append(pd.DataFrame([{"content": txt}]))
continue
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
elif q.get("value"):
outs.append(pd.DataFrame([{"content": q["value"]}]))
return outs
def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
up_cpns = self.get_upstream()
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
if self._param.query:
self._param.inputs = []
outs = self._fetch_outputs_from(self._param.query)
for out in outs:
records = out.to_dict("records")
content: str
if len(records) > 1:
content = "\n".join([str(d["content"]) for d in records])
else:
content = records[0]["content"]
self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content})
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
return df
upstream_outs = []
for u in reversed_up_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]:
res = {}
for var, o in self.get_input_elements().items():
v = self.get_param(var)
if v is None:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
o["component_id"] = u
upstream_outs.append(o)
continue
# if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]:
continue
if u.lower().find("answer") >= 0:
for r, c in self._canvas.history[::-1]:
if r == "user":
upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
break
break
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
continue
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
o["component_id"] = u
upstream_outs.append(o)
break
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
df = pd.concat(upstream_outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
self._param.inputs = []
for _, r in df.iterrows():
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
return df
def get_input_elements(self):
assert self._param.query, "Please verify the input parameters first."
eles = []
for q in self._param.query:
if q.get("component_id"):
cpn_id = q["component_id"]
if cpn_id.split("@")[0].lower().find("begin") >= 0:
cpn_id, key = cpn_id.split("@")
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
continue
eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
if isinstance(v, str) and self._canvas.is_reff(v):
self.set_input_value(var, self._canvas.get_variable_value(v))
else:
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
return eles
self.set_input_value(var, v)
res[var] = self.get_input_value(var)
return res
def get_stream_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
up_cpns = self.get_upstream()
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
def get_input_values(self) -> Union[Any, dict[str, Any]]:
if self._param.debug_inputs:
return self._param.debug_inputs
for u in reversed_up_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]
return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()}
@staticmethod
def be_output(v):
return pd.DataFrame([{"content": v}])
def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
res = {}
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE):
exp = r.group(1)
cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
res[exp] = {
"name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp,
"value": self._canvas.get_variable_value(exp),
"_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
"_cpn_id": cpn_id
}
return res
def get_component_name(self, cpn_id):
def get_input_elements(self) -> dict[str, Any]:
return self._param.inputs
def get_input_form(self) -> dict[str, dict]:
return self._param.get_input_form()
def set_input_value(self, key: str, value: Any) -> None:
if key not in self._param.inputs:
self._param.inputs[key] = {"value": None}
self._param.inputs[key]["value"] = value
def get_input_value(self, key: str) -> Any:
if key not in self._param.inputs:
return None
return self._param.inputs[key].get("value")
def get_component_name(self, cpn_id) -> str:
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
def debug(self, **kwargs):
return self._run([], **kwargs)
def get_param(self, name):
if hasattr(self._param, name):
return getattr(self._param, name)
def get_parent(self):
pid = self._canvas.get_component(self._id)["parent_id"]
def debug(self, **kwargs):
return self._invoke(**kwargs)
def get_parent(self) -> Union[object, None]:
pid = self._canvas.get_component(self._id).get("parent_id")
if not pid:
return
return self._canvas.get_component(pid)["obj"]
def get_upstream(self):
cpn_nms = self._canvas.get_component(self._id)["upstream"]
def get_upstream(self) -> List[str]:
cpn_nms = self._canvas.get_component(self._id)['upstream']
return cpn_nms
@staticmethod
def string_format(content: str, kv: dict[str, str]) -> str:
for n, v in kv.items():
content = re.sub(
r"\{%s\}" % re.escape(n), re.escape(v), content
)
return content
def exception_handler(self):
if not self._param.exception_method:
return
return {
"goto": self._param.exception_goto,
"comment": self._param.exception_comment,
"default_value": self._param.exception_default_value
}
def get_exception_default_value(self):
return self._param.exception_default_value