mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve? #9082 #6365 <u> **WARNING: it's not compatible with the older version of `Agent` module, which means that `Agent` from older versions can not work anymore.**</u> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,8 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import numbers
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import Any
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class SwitchParam(ComponentParamBase):
|
||||
@ -34,7 +39,7 @@ class SwitchParam(ComponentParamBase):
|
||||
}
|
||||
"""
|
||||
self.conditions = []
|
||||
self.end_cpn_id = "answer:0"
|
||||
self.end_cpn_ids = []
|
||||
self.operators = ['contains', 'not contains', 'start with', 'end with', 'empty', 'not empty', '=', '≠', '>',
|
||||
'<', '≥', '≤']
|
||||
|
||||
@ -43,54 +48,46 @@ class SwitchParam(ComponentParamBase):
|
||||
for cond in self.conditions:
|
||||
if not cond["to"]:
|
||||
raise ValueError("[Switch] 'To' can not be empty!")
|
||||
self.check_empty(self.end_cpn_ids, "[Switch] the ELSE/Other destination can not be empty.")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"urls": {
|
||||
"name": "URLs",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class Switch(ComponentBase, ABC):
|
||||
component_name = "Switch"
|
||||
|
||||
def get_dependent_components(self):
|
||||
res = []
|
||||
for cond in self._param.conditions:
|
||||
for item in cond["items"]:
|
||||
if not item["cpn_id"]:
|
||||
continue
|
||||
if item["cpn_id"].lower().find("begin") >= 0 or item["cpn_id"].lower().find("answer") >= 0:
|
||||
continue
|
||||
cid = item["cpn_id"].split("@")[0]
|
||||
res.append(cid)
|
||||
|
||||
return list(set(res))
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
||||
def _invoke(self, **kwargs):
|
||||
for cond in self._param.conditions:
|
||||
res = []
|
||||
for item in cond["items"]:
|
||||
if not item["cpn_id"]:
|
||||
continue
|
||||
cid = item["cpn_id"].split("@")[0]
|
||||
if item["cpn_id"].find("@") > 0:
|
||||
cpn_id, key = item["cpn_id"].split("@")
|
||||
for p in self._canvas.get_component(cid)["obj"]._param.query:
|
||||
if p["key"] == key:
|
||||
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
|
||||
break
|
||||
else:
|
||||
out = self._canvas.get_component(cid)["obj"].output(allow_partial=False)[1]
|
||||
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
|
||||
res.append(self.process_operator(cpn_input, item["operator"], item.get("value", "")))
|
||||
|
||||
cpn_v = self._canvas.get_variable_value(item["cpn_id"])
|
||||
self.set_input_value(item["cpn_id"], cpn_v)
|
||||
operatee = item.get("value", "")
|
||||
if isinstance(cpn_v, numbers.Number):
|
||||
operatee = float(operatee)
|
||||
res.append(self.process_operator(cpn_v, item["operator"], operatee))
|
||||
if cond["logical_operator"] != "and" and any(res):
|
||||
return Switch.be_output(cond["to"])
|
||||
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in cond["to"]])
|
||||
self.set_output("_next", cond["to"])
|
||||
return
|
||||
|
||||
if all(res):
|
||||
return Switch.be_output(cond["to"])
|
||||
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in cond["to"]])
|
||||
self.set_output("_next", cond["to"])
|
||||
return
|
||||
|
||||
return Switch.be_output(self._param.end_cpn_id)
|
||||
|
||||
def process_operator(self, input: str, operator: str, value: str) -> bool:
|
||||
if not isinstance(input, str) or not isinstance(value, str):
|
||||
raise ValueError('Invalid input or value type: string')
|
||||
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in self._param.end_cpn_ids])
|
||||
self.set_output("_next", self._param.end_cpn_ids)
|
||||
|
||||
def process_operator(self, input: Any, operator: str, value: Any) -> bool:
|
||||
if operator == "contains":
|
||||
return True if value.lower() in input.lower() else False
|
||||
elif operator == "not contains":
|
||||
|
||||
Reference in New Issue
Block a user