diff --git a/web/src/pages/agent/canvas/node/agent-node.tsx b/web/src/pages/agent/canvas/node/agent-node.tsx index 2a4057e31..731f155a7 100644 --- a/web/src/pages/agent/canvas/node/agent-node.tsx +++ b/web/src/pages/agent/canvas/node/agent-node.tsx @@ -54,13 +54,13 @@ function InnerAgentNode({ type="target" position={Position.Top} isConnectable={false} - id="f" + id={NodeHandleId.AgentTop} > state.deleteAgentToolNodeById, + ); const handleChange = useCallback( (value: string[]) => { @@ -29,11 +31,11 @@ export function ToolPopover({ children }: PropsWithChildren) { nodeId: node?.id, })(); } else { - deleteToolNode(node.id); // TODO: The tool node should be derived from the agent tools data + deleteAgentToolNodeById(node.id); // TODO: The tool node should be derived from the agent tools data } } }, - [addCanvasNode, deleteToolNode, node?.id, updateNodeTools], + [addCanvasNode, deleteAgentToolNodeById, node?.id, updateNodeTools], ); return ( diff --git a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts index 3bcf84484..cca31e68d 100644 --- a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts +++ b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts @@ -3,7 +3,6 @@ import { AgentFormContext } from '@/pages/agent/context'; import useGraphStore from '@/pages/agent/store'; import { get } from 'lodash'; import { useCallback, useContext, useMemo } from 'react'; -import { useDeleteToolNode } from '../use-delete-tool-node'; export function useGetNodeTools() { const node = useContext(AgentFormContext); @@ -48,7 +47,9 @@ export function useDeleteAgentNodeTools() { const { updateNodeForm } = useGraphStore((state) => state); const tools = useGetNodeTools(); const node = useContext(AgentFormContext); - const { deleteToolNode } = useDeleteToolNode(); + const deleteAgentToolNodeById = useGraphStore( + (state) => state.deleteAgentToolNodeById, + ); const deleteNodeTool = useCallback( (value: string) => () => { @@ -56,11 +57,11 @@ export function useDeleteAgentNodeTools() { if (node?.id) { updateNodeForm(node?.id, nextTools, ['tools']); if (nextTools.length === 0) { - deleteToolNode(node?.id); + deleteAgentToolNodeById(node?.id); } } }, - [deleteToolNode, node?.id, tools, updateNodeForm], + [deleteAgentToolNodeById, node?.id, tools, updateNodeForm], ); return { deleteNodeTool }; diff --git a/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts b/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts deleted file mode 100644 index b32274595..000000000 --- a/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { useCallback } from 'react'; -import { NodeHandleId } from '../../constant'; -import useGraphStore from '../../store'; - -export function useDeleteToolNode() { - const { edges, deleteEdgeById, deleteNodeById } = useGraphStore( - (state) => state, - ); - const deleteToolNode = useCallback( - (agentNodeId: string) => { - const edge = edges.find( - (x) => x.source === agentNodeId && x.sourceHandle === NodeHandleId.Tool, - ); - - if (edge) { - deleteEdgeById(edge.id); - deleteNodeById(edge.target); - } - }, - [deleteEdgeById, deleteNodeById, edges], - ); - - return { deleteToolNode }; -} diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts index 0bc58d025..2ae543dd0 100644 --- a/web/src/pages/agent/hooks/use-add-node.ts +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -315,7 +315,11 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { if (agentNode) { // Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes const allChildAgentNodeIds = edges - .filter((x) => x.source === nodeId && x.sourceHandle === 'e') + .filter( + (x) => + x.source === nodeId && + x.sourceHandle === NodeHandleId.AgentBottom, + ) .map((x) => x.target); const xAxises = nodes @@ -334,8 +338,8 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { addEdge({ source: nodeId, target: newNode.id, - sourceHandle: 'e', - targetHandle: 'f', + sourceHandle: NodeHandleId.AgentBottom, + targetHandle: NodeHandleId.AgentTop, }); } } else if (type === Operator.Tool) { diff --git a/web/src/pages/agent/hooks/use-before-delete.tsx b/web/src/pages/agent/hooks/use-before-delete.tsx index 14512ae96..d08333c86 100644 --- a/web/src/pages/agent/hooks/use-before-delete.tsx +++ b/web/src/pages/agent/hooks/use-before-delete.tsx @@ -1,14 +1,18 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; -import { OnBeforeDelete } from '@xyflow/react'; +import { Node, OnBeforeDelete } from '@xyflow/react'; import { Operator } from '../constant'; import useGraphStore from '../store'; +import { deleteAllDownstreamAgentsAndTool } from '../utils/delete-node'; const UndeletableNodes = [Operator.Begin, Operator.IterationStart]; export function useBeforeDelete() { - const getOperatorTypeFromId = useGraphStore( - (state) => state.getOperatorTypeFromId, - ); + const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state); + + const agentPredicate = (node: Node) => { + return getOperatorTypeFromId(node.id) === Operator.Agent; + }; + const handleBeforeDelete: OnBeforeDelete = async ({ nodes, // Nodes to be deleted edges, // Edges to be deleted @@ -47,6 +51,27 @@ export function useBeforeDelete() { return true; }); + // Delete the agent and tool nodes downstream of the agent node + if (nodes.some(agentPredicate)) { + nodes.filter(agentPredicate).forEach((node) => { + const { downstreamAgentAndToolEdges, downstreamAgentAndToolNodeIds } = + deleteAllDownstreamAgentsAndTool(node.id, edges); + + downstreamAgentAndToolNodeIds.forEach((nodeId) => { + const currentNode = getNode(nodeId); + if (toBeDeletedNodes.every((x) => x.id !== nodeId) && currentNode) { + toBeDeletedNodes.push(currentNode); + } + }); + + downstreamAgentAndToolEdges.forEach((edge) => { + if (toBeDeletedEdges.every((x) => x.id !== edge.id)) { + toBeDeletedEdges.push(edge); + } + }); + }, []); + } + return { nodes: toBeDeletedNodes, edges: toBeDeletedEdges, diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index 26bb3750e..31ad0a46b 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -21,7 +21,7 @@ import lodashSet from 'lodash/set'; import { create } from 'zustand'; import { devtools } from 'zustand/middleware'; import { immer } from 'zustand/middleware/immer'; -import { Operator, SwitchElseTo } from './constant'; +import { NodeHandleId, Operator, SwitchElseTo } from './constant'; import { duplicateNodeForm, generateDuplicateNode, @@ -30,6 +30,7 @@ import { isEdgeEqual, mapEdgeMouseEvent, } from './utils'; +import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node'; export type RFState = { nodes: RAGFlowNodeType[]; @@ -70,6 +71,8 @@ export type RFState = { deleteEdge: () => void; deleteEdgeById: (id: string) => void; deleteNodeById: (id: string) => void; + deleteAgentDownstreamNodesById: (id: string) => void; + deleteAgentToolNodeById: (id: string) => void; deleteIterationNodeById: (id: string) => void; deleteEdgeBySourceAndSourceHandle: (connection: Partial) => void; findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; @@ -370,7 +373,16 @@ const useGraphStore = create()( }); }, deleteNodeById: (id: string) => { - const { nodes, edges } = get(); + const { + nodes, + edges, + getOperatorTypeFromId, + deleteAgentDownstreamNodesById, + } = get(); + if (getOperatorTypeFromId(id) === Operator.Agent) { + deleteAgentDownstreamNodesById(id); + return; + } set({ nodes: nodes.filter((node) => node.id !== id), edges: edges @@ -378,6 +390,38 @@ const useGraphStore = create()( .filter((edge) => edge.target !== id), }); }, + deleteAgentDownstreamNodesById: (id) => { + const { edges, nodes } = get(); + + const { downstreamAgentAndToolNodeIds, downstreamAgentAndToolEdges } = + deleteAllDownstreamAgentsAndTool(id, edges); + + set({ + nodes: nodes.filter( + (node) => + !downstreamAgentAndToolNodeIds.some((x) => x === node.id) && + node.id !== id, + ), + edges: edges.filter( + (edge) => + edge.source !== id && + edge.target !== id && + !downstreamAgentAndToolEdges.some((x) => x.id === edge.id), + ), + }); + }, + deleteAgentToolNodeById: (id) => { + const { edges, deleteEdgeById, deleteNodeById } = get(); + + const edge = edges.find( + (x) => x.source === id && x.sourceHandle === NodeHandleId.Tool, + ); + + if (edge) { + deleteEdgeById(edge.id); + deleteNodeById(edge.target); + } + }, deleteIterationNodeById: (id: string) => { const { nodes, edges } = get(); const children = nodes.filter((node) => node.parentId === id); diff --git a/web/src/pages/agent/utils/delete-node.ts b/web/src/pages/agent/utils/delete-node.ts new file mode 100644 index 000000000..4f1c18ebc --- /dev/null +++ b/web/src/pages/agent/utils/delete-node.ts @@ -0,0 +1,34 @@ +import { Edge } from '@xyflow/react'; +import { filterAllDownstreamAgentAndToolNodeIds } from './filter-downstream-nodes'; + +// Delete all downstream agent and tool operators of the current agent operator +export function deleteAllDownstreamAgentsAndTool( + nodeId: string, + edges: Edge[], +) { + const downstreamAgentAndToolNodeIds = filterAllDownstreamAgentAndToolNodeIds( + edges, + [nodeId], + ); + + const downstreamAgentAndToolEdges = downstreamAgentAndToolNodeIds.reduce< + Edge[] + >((pre, cur) => { + const relatedEdges = edges.filter( + (x) => x.source === cur || x.target === cur, + ); + + relatedEdges.forEach((x) => { + if (!pre.some((y) => y.id !== x.id)) { + pre.push(x); + } + }); + + return pre; + }, []); + + return { + downstreamAgentAndToolNodeIds, + downstreamAgentAndToolEdges, + }; +} diff --git a/web/src/pages/agent/utils/filter-downstream-nodes.ts b/web/src/pages/agent/utils/filter-downstream-nodes.ts new file mode 100644 index 000000000..be9b350a9 --- /dev/null +++ b/web/src/pages/agent/utils/filter-downstream-nodes.ts @@ -0,0 +1,31 @@ +import { Edge } from '@xyflow/react'; +import { NodeHandleId } from '../constant'; + +// Get all downstream agent operators of the current agent operator +export function filterAllDownstreamAgentAndToolNodeIds( + edges: Edge[], + nodeIds: string[], +) { + return nodeIds.reduce((pre, nodeId) => { + const currentEdges = edges.filter( + (x) => + x.source === nodeId && + (x.sourceHandle === NodeHandleId.AgentBottom || + x.sourceHandle === NodeHandleId.Tool), + ); + + const downstreamNodeIds: string[] = currentEdges.map((x) => x.target); + + const ids = downstreamNodeIds.concat( + filterAllDownstreamAgentAndToolNodeIds(edges, downstreamNodeIds), + ); + + ids.forEach((x) => { + if (pre.every((y) => y !== x)) { + pre.push(x); + } + }); + + return pre; + }, []); +}