diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx index fe7b8af86..dabbffc72 100644 --- a/web/src/pages/agent/form/agent-form/index.tsx +++ b/web/src/pages/agent/form/agent-form/index.tsx @@ -17,10 +17,11 @@ import { useContext, useMemo } from 'react'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { NodeHandleId, Operator, initialAgentValues } from '../../constant'; +import { Operator, initialAgentValues } from '../../constant'; import { AgentInstanceContext } from '../../context'; import { INextOperatorForm } from '../../interface'; import useGraphStore from '../../store'; +import { isBottomSubAgent } from '../../utils'; import { Output } from '../components/output'; import { PromptEditor } from '../components/prompt-editor'; import { AgentTools } from './agent-tools'; @@ -57,10 +58,7 @@ const AgentForm = ({ node }: INextOperatorForm) => { const defaultValues = useValues(node); const isSubAgent = useMemo(() => { - const edge = edges.find( - (x) => x.target === node?.id && x.targetHandle === NodeHandleId.AgentTop, - ); - return !!edge; + return isBottomSubAgent(edges, node?.id); }, [edges, node?.id]); const outputList = useMemo(() => { diff --git a/web/src/pages/agent/utils.ts b/web/src/pages/agent/utils.ts index f67042285..c4c0d0021 100644 --- a/web/src/pages/agent/utils.ts +++ b/web/src/pages/agent/utils.ts @@ -15,6 +15,7 @@ import { v4 as uuidv4 } from 'uuid'; import { CategorizeAnchorPointPositions, NoDebugOperatorsList, + NodeHandleId, NodeMap, Operator, } from './constant'; @@ -100,12 +101,20 @@ const buildComponentDownstreamOrUpstream = ( .filter((y) => { const node = nodes.find((x) => x.id === nodeId); let isNotUpstreamTool = true; + let isNotUpstreamAgent = true; if (isBuildDownstream && node?.data.label === Operator.Agent) { - isNotUpstreamTool = !y.target.startsWith(Operator.Tool); // Exclude the tool operator downstream of the agent operator + // Exclude the tool operator downstream of the agent operator + isNotUpstreamTool = !y.target.startsWith(Operator.Tool); + // Exclude the agent operator downstream of the agent operator + isNotUpstreamAgent = !( + y.target.startsWith(Operator.Agent) && + y.targetHandle === NodeHandleId.AgentTop + ); } return ( y[isBuildDownstream ? 'source' : 'target'] === nodeId && - isNotUpstreamTool + isNotUpstreamTool && + isNotUpstreamAgent ); }) .map((y) => y[isBuildDownstream ? 'target' : 'source']); @@ -130,6 +139,25 @@ const removeUselessDataInTheOperator = curry( // return values; // }); +function buildAgentTools(edges: Edge[], nodes: Node[], nodeId: string) { + const node = nodes.find((x) => x.id === nodeId); + const params = { ...(node?.data.form ?? {}) }; + if (node && node.data.label === Operator.Agent) { + const bottomSubAgentEdges = edges.filter( + (x) => x.source === nodeId && x.sourceHandle === NodeHandleId.AgentBottom, + ); + + (params as IAgentForm).tools = (params as IAgentForm).tools.concat( + bottomSubAgentEdges.map((x) => { + const formData = buildAgentTools(edges, nodes, x.target); + + return { component_name: Operator.Agent, params: { ...formData } }; + }), + ); + } + return params; +} + const buildOperatorParams = (operatorName: string) => pipe( removeUselessDataInTheOperator(operatorName), @@ -138,6 +166,13 @@ const buildOperatorParams = (operatorName: string) => const ExcludeOperators = [Operator.Note, Operator.Tool]; +export function isBottomSubAgent(edges: Edge[], nodeId?: string) { + const edge = edges.find( + (x) => x.target === nodeId && x.targetHandle === NodeHandleId.AgentTop, + ); + return !!edge; +} + // construct a dsl based on the node information of the graph export const buildDslComponentsByGraph = ( nodes: RAGFlowNodeType[], @@ -147,18 +182,21 @@ export const buildDslComponentsByGraph = ( const components: DSLComponents = {}; nodes - ?.filter((x) => !ExcludeOperators.some((y) => y === x.data.label)) + ?.filter( + (x) => + !ExcludeOperators.some((y) => y === x.data.label) && + !isBottomSubAgent(edges, x.id), + ) .forEach((x) => { const id = x.id; const operatorName = x.data.label; + + const params = buildAgentTools(edges, nodes, id); components[id] = { obj: { ...(oldDslComponents[id]?.obj ?? {}), component_name: operatorName, - params: - buildOperatorParams(operatorName)( - x.data.form as Record, - ) ?? {}, + params: buildOperatorParams(operatorName)(params) ?? {}, }, downstream: buildComponentDownstreamOrUpstream(edges, id, true, nodes), upstream: buildComponentDownstreamOrUpstream(edges, id, false, nodes),