diff --git a/web/src/pages/agent/canvas/node/agent-node.tsx b/web/src/pages/agent/canvas/node/agent-node.tsx
index 2a4057e31..731f155a7 100644
--- a/web/src/pages/agent/canvas/node/agent-node.tsx
+++ b/web/src/pages/agent/canvas/node/agent-node.tsx
@@ -54,13 +54,13 @@ function InnerAgentNode({
type="target"
position={Position.Top}
isConnectable={false}
- id="f"
+ id={NodeHandleId.AgentTop}
>
state.deleteAgentToolNodeById,
+ );
const handleChange = useCallback(
(value: string[]) => {
@@ -29,11 +31,11 @@ export function ToolPopover({ children }: PropsWithChildren) {
nodeId: node?.id,
})();
} else {
- deleteToolNode(node.id); // TODO: The tool node should be derived from the agent tools data
+ deleteAgentToolNodeById(node.id); // TODO: The tool node should be derived from the agent tools data
}
}
},
- [addCanvasNode, deleteToolNode, node?.id, updateNodeTools],
+ [addCanvasNode, deleteAgentToolNodeById, node?.id, updateNodeTools],
);
return (
diff --git a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts
index 3bcf84484..cca31e68d 100644
--- a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts
+++ b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts
@@ -3,7 +3,6 @@ import { AgentFormContext } from '@/pages/agent/context';
import useGraphStore from '@/pages/agent/store';
import { get } from 'lodash';
import { useCallback, useContext, useMemo } from 'react';
-import { useDeleteToolNode } from '../use-delete-tool-node';
export function useGetNodeTools() {
const node = useContext(AgentFormContext);
@@ -48,7 +47,9 @@ export function useDeleteAgentNodeTools() {
const { updateNodeForm } = useGraphStore((state) => state);
const tools = useGetNodeTools();
const node = useContext(AgentFormContext);
- const { deleteToolNode } = useDeleteToolNode();
+ const deleteAgentToolNodeById = useGraphStore(
+ (state) => state.deleteAgentToolNodeById,
+ );
const deleteNodeTool = useCallback(
(value: string) => () => {
@@ -56,11 +57,11 @@ export function useDeleteAgentNodeTools() {
if (node?.id) {
updateNodeForm(node?.id, nextTools, ['tools']);
if (nextTools.length === 0) {
- deleteToolNode(node?.id);
+ deleteAgentToolNodeById(node?.id);
}
}
},
- [deleteToolNode, node?.id, tools, updateNodeForm],
+ [deleteAgentToolNodeById, node?.id, tools, updateNodeForm],
);
return { deleteNodeTool };
diff --git a/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts b/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts
deleted file mode 100644
index b32274595..000000000
--- a/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts
+++ /dev/null
@@ -1,24 +0,0 @@
-import { useCallback } from 'react';
-import { NodeHandleId } from '../../constant';
-import useGraphStore from '../../store';
-
-export function useDeleteToolNode() {
- const { edges, deleteEdgeById, deleteNodeById } = useGraphStore(
- (state) => state,
- );
- const deleteToolNode = useCallback(
- (agentNodeId: string) => {
- const edge = edges.find(
- (x) => x.source === agentNodeId && x.sourceHandle === NodeHandleId.Tool,
- );
-
- if (edge) {
- deleteEdgeById(edge.id);
- deleteNodeById(edge.target);
- }
- },
- [deleteEdgeById, deleteNodeById, edges],
- );
-
- return { deleteToolNode };
-}
diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts
index 0bc58d025..2ae543dd0 100644
--- a/web/src/pages/agent/hooks/use-add-node.ts
+++ b/web/src/pages/agent/hooks/use-add-node.ts
@@ -315,7 +315,11 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) {
if (agentNode) {
// Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes
const allChildAgentNodeIds = edges
- .filter((x) => x.source === nodeId && x.sourceHandle === 'e')
+ .filter(
+ (x) =>
+ x.source === nodeId &&
+ x.sourceHandle === NodeHandleId.AgentBottom,
+ )
.map((x) => x.target);
const xAxises = nodes
@@ -334,8 +338,8 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) {
addEdge({
source: nodeId,
target: newNode.id,
- sourceHandle: 'e',
- targetHandle: 'f',
+ sourceHandle: NodeHandleId.AgentBottom,
+ targetHandle: NodeHandleId.AgentTop,
});
}
} else if (type === Operator.Tool) {
diff --git a/web/src/pages/agent/hooks/use-before-delete.tsx b/web/src/pages/agent/hooks/use-before-delete.tsx
index 14512ae96..d08333c86 100644
--- a/web/src/pages/agent/hooks/use-before-delete.tsx
+++ b/web/src/pages/agent/hooks/use-before-delete.tsx
@@ -1,14 +1,18 @@
import { RAGFlowNodeType } from '@/interfaces/database/flow';
-import { OnBeforeDelete } from '@xyflow/react';
+import { Node, OnBeforeDelete } from '@xyflow/react';
import { Operator } from '../constant';
import useGraphStore from '../store';
+import { deleteAllDownstreamAgentsAndTool } from '../utils/delete-node';
const UndeletableNodes = [Operator.Begin, Operator.IterationStart];
export function useBeforeDelete() {
- const getOperatorTypeFromId = useGraphStore(
- (state) => state.getOperatorTypeFromId,
- );
+ const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state);
+
+ const agentPredicate = (node: Node) => {
+ return getOperatorTypeFromId(node.id) === Operator.Agent;
+ };
+
const handleBeforeDelete: OnBeforeDelete = async ({
nodes, // Nodes to be deleted
edges, // Edges to be deleted
@@ -47,6 +51,27 @@ export function useBeforeDelete() {
return true;
});
+ // Delete the agent and tool nodes downstream of the agent node
+ if (nodes.some(agentPredicate)) {
+ nodes.filter(agentPredicate).forEach((node) => {
+ const { downstreamAgentAndToolEdges, downstreamAgentAndToolNodeIds } =
+ deleteAllDownstreamAgentsAndTool(node.id, edges);
+
+ downstreamAgentAndToolNodeIds.forEach((nodeId) => {
+ const currentNode = getNode(nodeId);
+ if (toBeDeletedNodes.every((x) => x.id !== nodeId) && currentNode) {
+ toBeDeletedNodes.push(currentNode);
+ }
+ });
+
+ downstreamAgentAndToolEdges.forEach((edge) => {
+ if (toBeDeletedEdges.every((x) => x.id !== edge.id)) {
+ toBeDeletedEdges.push(edge);
+ }
+ });
+ }, []);
+ }
+
return {
nodes: toBeDeletedNodes,
edges: toBeDeletedEdges,
diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts
index 26bb3750e..31ad0a46b 100644
--- a/web/src/pages/agent/store.ts
+++ b/web/src/pages/agent/store.ts
@@ -21,7 +21,7 @@ import lodashSet from 'lodash/set';
import { create } from 'zustand';
import { devtools } from 'zustand/middleware';
import { immer } from 'zustand/middleware/immer';
-import { Operator, SwitchElseTo } from './constant';
+import { NodeHandleId, Operator, SwitchElseTo } from './constant';
import {
duplicateNodeForm,
generateDuplicateNode,
@@ -30,6 +30,7 @@ import {
isEdgeEqual,
mapEdgeMouseEvent,
} from './utils';
+import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node';
export type RFState = {
nodes: RAGFlowNodeType[];
@@ -70,6 +71,8 @@ export type RFState = {
deleteEdge: () => void;
deleteEdgeById: (id: string) => void;
deleteNodeById: (id: string) => void;
+ deleteAgentDownstreamNodesById: (id: string) => void;
+ deleteAgentToolNodeById: (id: string) => void;
deleteIterationNodeById: (id: string) => void;
deleteEdgeBySourceAndSourceHandle: (connection: Partial) => void;
findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
@@ -370,7 +373,16 @@ const useGraphStore = create()(
});
},
deleteNodeById: (id: string) => {
- const { nodes, edges } = get();
+ const {
+ nodes,
+ edges,
+ getOperatorTypeFromId,
+ deleteAgentDownstreamNodesById,
+ } = get();
+ if (getOperatorTypeFromId(id) === Operator.Agent) {
+ deleteAgentDownstreamNodesById(id);
+ return;
+ }
set({
nodes: nodes.filter((node) => node.id !== id),
edges: edges
@@ -378,6 +390,38 @@ const useGraphStore = create()(
.filter((edge) => edge.target !== id),
});
},
+ deleteAgentDownstreamNodesById: (id) => {
+ const { edges, nodes } = get();
+
+ const { downstreamAgentAndToolNodeIds, downstreamAgentAndToolEdges } =
+ deleteAllDownstreamAgentsAndTool(id, edges);
+
+ set({
+ nodes: nodes.filter(
+ (node) =>
+ !downstreamAgentAndToolNodeIds.some((x) => x === node.id) &&
+ node.id !== id,
+ ),
+ edges: edges.filter(
+ (edge) =>
+ edge.source !== id &&
+ edge.target !== id &&
+ !downstreamAgentAndToolEdges.some((x) => x.id === edge.id),
+ ),
+ });
+ },
+ deleteAgentToolNodeById: (id) => {
+ const { edges, deleteEdgeById, deleteNodeById } = get();
+
+ const edge = edges.find(
+ (x) => x.source === id && x.sourceHandle === NodeHandleId.Tool,
+ );
+
+ if (edge) {
+ deleteEdgeById(edge.id);
+ deleteNodeById(edge.target);
+ }
+ },
deleteIterationNodeById: (id: string) => {
const { nodes, edges } = get();
const children = nodes.filter((node) => node.parentId === id);
diff --git a/web/src/pages/agent/utils/delete-node.ts b/web/src/pages/agent/utils/delete-node.ts
new file mode 100644
index 000000000..4f1c18ebc
--- /dev/null
+++ b/web/src/pages/agent/utils/delete-node.ts
@@ -0,0 +1,34 @@
+import { Edge } from '@xyflow/react';
+import { filterAllDownstreamAgentAndToolNodeIds } from './filter-downstream-nodes';
+
+// Delete all downstream agent and tool operators of the current agent operator
+export function deleteAllDownstreamAgentsAndTool(
+ nodeId: string,
+ edges: Edge[],
+) {
+ const downstreamAgentAndToolNodeIds = filterAllDownstreamAgentAndToolNodeIds(
+ edges,
+ [nodeId],
+ );
+
+ const downstreamAgentAndToolEdges = downstreamAgentAndToolNodeIds.reduce<
+ Edge[]
+ >((pre, cur) => {
+ const relatedEdges = edges.filter(
+ (x) => x.source === cur || x.target === cur,
+ );
+
+ relatedEdges.forEach((x) => {
+ if (!pre.some((y) => y.id !== x.id)) {
+ pre.push(x);
+ }
+ });
+
+ return pre;
+ }, []);
+
+ return {
+ downstreamAgentAndToolNodeIds,
+ downstreamAgentAndToolEdges,
+ };
+}
diff --git a/web/src/pages/agent/utils/filter-downstream-nodes.ts b/web/src/pages/agent/utils/filter-downstream-nodes.ts
new file mode 100644
index 000000000..be9b350a9
--- /dev/null
+++ b/web/src/pages/agent/utils/filter-downstream-nodes.ts
@@ -0,0 +1,31 @@
+import { Edge } from '@xyflow/react';
+import { NodeHandleId } from '../constant';
+
+// Get all downstream agent operators of the current agent operator
+export function filterAllDownstreamAgentAndToolNodeIds(
+ edges: Edge[],
+ nodeIds: string[],
+) {
+ return nodeIds.reduce((pre, nodeId) => {
+ const currentEdges = edges.filter(
+ (x) =>
+ x.source === nodeId &&
+ (x.sourceHandle === NodeHandleId.AgentBottom ||
+ x.sourceHandle === NodeHandleId.Tool),
+ );
+
+ const downstreamNodeIds: string[] = currentEdges.map((x) => x.target);
+
+ const ids = downstreamNodeIds.concat(
+ filterAllDownstreamAgentAndToolNodeIds(edges, downstreamNodeIds),
+ );
+
+ ids.forEach((x) => {
+ if (pre.every((y) => y !== x)) {
+ pre.push(x);
+ }
+ });
+
+ return pre;
+ }, []);
+}