From 07e37560fc59cf6e5feead51e1134ea8f32d1cad Mon Sep 17 00:00:00 2001 From: balibabu Date: Wed, 30 Jul 2025 19:14:33 +0800 Subject: [PATCH] Feat: Handling abnormal anchor points of agent operators #3221 (#9121) ### What problem does this PR solve? Feat: Handling abnormal anchor points of agent operators #3221 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .../pages/agent/canvas/node/agent-node.tsx | 39 ++++++- web/src/pages/agent/constant.tsx | 22 ++-- web/src/pages/agent/form-hooks.ts | 34 ------ web/src/pages/agent/form/agent-form/index.tsx | 62 ++++++---- .../categorize-form/dynamic-categorize.tsx | 2 +- .../agent/form/tavily-extract-form/index.tsx | 8 +- .../pages/agent/form/tavily-form/index.tsx | 17 +-- web/src/pages/agent/store.ts | 21 +--- web/src/pages/agent/utils.test.ts | 106 ------------------ web/src/pages/agent/utils.ts | 91 +++------------ 10 files changed, 125 insertions(+), 277 deletions(-) delete mode 100644 web/src/pages/agent/utils.test.ts diff --git a/web/src/pages/agent/canvas/node/agent-node.tsx b/web/src/pages/agent/canvas/node/agent-node.tsx index 1c9116ec3..dd61d2711 100644 --- a/web/src/pages/agent/canvas/node/agent-node.tsx +++ b/web/src/pages/agent/canvas/node/agent-node.tsx @@ -1,7 +1,9 @@ +import LLMLabel from '@/components/llm-select/llm-label'; import { IAgentNode } from '@/interfaces/database/flow'; import { Handle, NodeProps, Position } from '@xyflow/react'; +import { get } from 'lodash'; import { memo, useMemo } from 'react'; -import { NodeHandleId } from '../../constant'; +import { AgentExceptionMethod, NodeHandleId } from '../../constant'; import useGraphStore from '../../store'; import { isBottomSubAgent } from '../../utils'; import { CommonHandle } from './handle'; @@ -23,6 +25,14 @@ function InnerAgentNode({ return !isBottomSubAgent(edges, id); }, [edges, id]); + const exceptionMethod = useMemo(() => { + return get(data, 'form.exception_method'); + }, [data]); + + const isGotoMethod = useMemo(() => { + return exceptionMethod === AgentExceptionMethod.Goto; + }, [exceptionMethod]); + return ( @@ -48,6 +58,7 @@ function InnerAgentNode({ > )} + +
+
+ +
+ {(isGotoMethod || + exceptionMethod === AgentExceptionMethod.Comment) && ( +
+ Abnormal + + {isGotoMethod ? 'Exception branch' : 'Output default value'} + +
+ )} +
+ {isGotoMethod && ( + + )}
); diff --git a/web/src/pages/agent/constant.tsx b/web/src/pages/agent/constant.tsx index 8318f0d46..8e6af99fe 100644 --- a/web/src/pages/agent/constant.tsx +++ b/web/src/pages/agent/constant.tsx @@ -644,19 +644,20 @@ export const initialAgentValues = { max_rounds: 5, exception_method: null, exception_comment: '', - exception_goto: '', + exception_goto: [], + exception_default_value: '', tools: [], mcp: [], outputs: { - structured_output: { - // topic: { - // type: 'string', - // description: - // 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.', - // enum: ['general', 'news'], - // default: 'general', - // }, - }, + // structured_output: { + // topic: { + // type: 'string', + // description: + // 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.', + // enum: ['general', 'news'], + // default: 'general', + // }, + // }, content: { type: 'string', value: '', @@ -932,6 +933,7 @@ export enum NodeHandleId { Tool = 'tool', AgentTop = 'agentTop', AgentBottom = 'agentBottom', + AgentException = 'agentException', } export enum VariableType { diff --git a/web/src/pages/agent/form-hooks.ts b/web/src/pages/agent/form-hooks.ts index 372288e6f..fef7d37b9 100644 --- a/web/src/pages/agent/form-hooks.ts +++ b/web/src/pages/agent/form-hooks.ts @@ -30,40 +30,6 @@ export const useBuildFormSelectOptions = ( return buildCategorizeToOptions; }; -/** - * dumped - * @param nodeId - * @returns - */ -export const useHandleFormSelectChange = (nodeId?: string) => { - const { addEdge, deleteEdgeBySourceAndSourceHandle } = useGraphStore( - (state) => state, - ); - const handleSelectChange = useCallback( - (name?: string) => (value?: string) => { - if (nodeId && name) { - if (value) { - addEdge({ - source: nodeId, - target: value, - sourceHandle: name, - targetHandle: null, - }); - } else { - // clear selected value - deleteEdgeBySourceAndSourceHandle({ - source: nodeId, - sourceHandle: name, - }); - } - } - }, - [addEdge, nodeId, deleteEdgeBySourceAndSourceHandle], - ); - - return { handleSelectChange }; -}; - export const useBuildSortOptions = () => { const { t } = useTranslate('flow'); diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx index 080f19902..c12b992c9 100644 --- a/web/src/pages/agent/form/agent-form/index.tsx +++ b/web/src/pages/agent/form/agent-form/index.tsx @@ -19,19 +19,22 @@ import { LlmModelType } from '@/constants/knowledge'; import { useFindLlmByUuid } from '@/hooks/use-llm-request'; import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; -import { memo, useMemo } from 'react'; +import { memo, useEffect, useMemo } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; import { AgentExceptionMethod, + NodeHandleId, VariableType, initialAgentValues, } from '../../constant'; import { INextOperatorForm } from '../../interface'; import useGraphStore from '../../store'; import { isBottomSubAgent } from '../../utils'; +import { buildOutputList } from '../../utils/build-output-list'; import { DescriptionField } from '../components/description-field'; +import { FormWrapper } from '../components/form-wrapper'; import { Output } from '../components/output'; import { PromptEditor } from '../components/prompt-editor'; import { QueryVariable } from '../components/query-variable'; @@ -69,13 +72,18 @@ const FormSchema = z.object({ max_rounds: z.coerce.number().optional(), exception_method: z.string().nullable(), exception_comment: z.string().optional(), - exception_goto: z.string().optional(), + exception_goto: z.array(z.string()).optional(), + exception_default_value: z.string().optional(), ...LargeModelFilterFormSchema, }); +const outputList = buildOutputList(initialAgentValues.outputs); + function AgentForm({ node }: INextOperatorForm) { const { t } = useTranslation(); - const { edges } = useGraphStore((state) => state); + const { edges, deleteEdgesBySourceAndSourceHandle } = useGraphStore( + (state) => state, + ); const defaultValues = useValues(node); @@ -83,12 +91,6 @@ function AgentForm({ node }: INextOperatorForm) { return isBottomSubAgent(edges, node?.id); }, [edges, node?.id]); - const outputList = useMemo(() => { - return [ - { title: 'content', type: initialAgentValues.outputs.content.type }, - ]; - }, []); - const form = useForm>({ defaultValues: defaultValues, resolver: zodResolver(FormSchema), @@ -98,16 +100,27 @@ function AgentForm({ node }: INextOperatorForm) { const findLlmByUuid = useFindLlmByUuid(); + const exceptionMethod = useWatch({ + control: form.control, + name: 'exception_method', + }); + + useEffect(() => { + if (exceptionMethod !== AgentExceptionMethod.Goto) { + if (node?.id) { + deleteEdgesBySourceAndSourceHandle( + node?.id, + NodeHandleId.AgentException, + ); + } + } + }, [deleteEdgesBySourceAndSourceHandle, exceptionMethod, node?.id]); + useWatchFormChange(node?.id, form); return (
- { - e.preventDefault(); - }} - > + {isSubAgent && } @@ -219,6 +232,18 @@ function AgentForm({ node }: INextOperatorForm) { )} /> + ( + + Exception default value + + + + + )} + /> )} /> - - +
); } diff --git a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx index 51979853e..0807f7bfa 100644 --- a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx +++ b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx @@ -177,7 +177,7 @@ const DynamicCategorize = ({ nodeId }: IProps) => { const FormSchema = useCreateCategorizeFormSchema(); const deleteCategorizeCaseEdges = useGraphStore( - (state) => state.deleteCategorizeCaseEdges, + (state) => state.deleteEdgesBySourceAndSourceHandle, ); const form = useFormContext>(); const { t } = useTranslate('flow'); diff --git a/web/src/pages/agent/form/tavily-extract-form/index.tsx b/web/src/pages/agent/form/tavily-extract-form/index.tsx index 23cadcefc..87c8ae75e 100644 --- a/web/src/pages/agent/form/tavily-extract-form/index.tsx +++ b/web/src/pages/agent/form/tavily-extract-form/index.tsx @@ -7,7 +7,6 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; import { RAGFlowSelect } from '@/components/ui/select'; import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; @@ -26,6 +25,7 @@ import { buildOutputList } from '../../utils/build-output-list'; import { ApiKeyField } from '../components/api-key-field'; import { FormWrapper } from '../components/form-wrapper'; import { Output } from '../components/output'; +import { PromptEditor } from '../components/prompt-editor'; import { TavilyFormSchema } from '../tavily-form'; const outputList = buildOutputList(initialTavilyExtractValues.outputs); @@ -64,7 +64,11 @@ function TavilyExtractForm({ node }: INextOperatorForm) { URL - + diff --git a/web/src/pages/agent/form/tavily-form/index.tsx b/web/src/pages/agent/form/tavily-form/index.tsx index f74abb2e0..e1acae98b 100644 --- a/web/src/pages/agent/form/tavily-form/index.tsx +++ b/web/src/pages/agent/form/tavily-form/index.tsx @@ -12,7 +12,7 @@ import { RAGFlowSelect } from '@/components/ui/select'; import { Switch } from '@/components/ui/switch'; import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; -import { memo, useMemo } from 'react'; +import { memo } from 'react'; import { useForm } from 'react-hook-form'; import { z } from 'zod'; import { @@ -21,9 +21,10 @@ import { initialTavilyValues, } from '../../constant'; import { INextOperatorForm } from '../../interface'; +import { buildOutputList } from '../../utils/build-output-list'; import { ApiKeyField } from '../components/api-key-field'; import { FormWrapper } from '../components/form-wrapper'; -import { Output, OutputType } from '../components/output'; +import { Output } from '../components/output'; import { QueryVariable } from '../components/query-variable'; import { DynamicDomain } from './dynamic-domain'; import { useValues } from './use-values'; @@ -33,6 +34,8 @@ export const TavilyFormSchema = { api_key: z.string(), }; +const outputList = buildOutputList(initialTavilyValues.outputs); + function TavilyForm({ node }: INextOperatorForm) { const values = useValues(node); @@ -56,16 +59,6 @@ function TavilyForm({ node }: INextOperatorForm) { resolver: zodResolver(FormSchema), }); - const outputList = useMemo(() => { - return Object.entries(initialTavilyValues.outputs).reduce( - (pre, [key, val]) => { - pre.push({ title: key, type: val.type }); - return pre; - }, - [], - ); - }, []); - useWatchFormChange(node?.id, form); return ( diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index 66d29edbe..46745f995 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -74,7 +74,6 @@ export type RFState = { deleteAgentDownstreamNodesById: (id: string) => void; deleteAgentToolNodeById: (id: string) => void; deleteIterationNodeById: (id: string) => void; - deleteEdgeBySourceAndSourceHandle: (connection: Partial) => void; findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; updateMutableNodeFormItem: (id: string, field: string, value: any) => void; getOperatorTypeFromId: (id?: string | null) => string | undefined; @@ -84,7 +83,10 @@ export type RFState = { setClickedNodeId: (id?: string) => void; setClickedToolId: (id?: string) => void; findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined; - deleteCategorizeCaseEdges: (source: string, sourceHandle: string) => void; // Deleting a condition of a classification operator will delete the related edge + deleteEdgesBySourceAndSourceHandle: ( + source: string, + sourceHandle: string, + ) => void; // Deleting a condition of a classification operator will delete the related edge findAgentToolNodeById: (id: string | null) => string | undefined; selectNodeIds: (nodeIds: string[]) => void; }; @@ -330,19 +332,6 @@ const useGraphStore = create()( edges: edges.filter((edge) => edge.id !== id), }); }, - deleteEdgeBySourceAndSourceHandle: ({ - source, - sourceHandle, - }: Partial) => { - const { edges } = get(); - const nextEdges = edges.filter( - (edge) => - edge.source !== source || edge.sourceHandle !== sourceHandle, - ); - set({ - edges: nextEdges, - }); - }, deleteNodeById: (id: string) => { const { nodes, @@ -511,7 +500,7 @@ const useGraphStore = create()( const edge = edges.find((x) => x.target === id); return getNode(edge?.source); }, - deleteCategorizeCaseEdges: (source, sourceHandle) => { + deleteEdgesBySourceAndSourceHandle: (source, sourceHandle) => { const { edges, setEdges } = get(); setEdges( edges.filter( diff --git a/web/src/pages/agent/utils.test.ts b/web/src/pages/agent/utils.test.ts deleted file mode 100644 index dbb89ce7e..000000000 --- a/web/src/pages/agent/utils.test.ts +++ /dev/null @@ -1,106 +0,0 @@ -import fs from 'fs'; -import path from 'path'; -import customer_service from '../../../../graph/test/dsl_examples/customer_service.json'; -import headhunter_zh from '../../../../graph/test/dsl_examples/headhunter_zh.json'; -import interpreter from '../../../../graph/test/dsl_examples/interpreter.json'; -import retrievalRelevantRewriteAndGenerate from '../../../../graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json'; -import { dsl } from './mock'; -import { buildNodesAndEdgesFromDSLComponents } from './utils'; - -test('buildNodesAndEdgesFromDSLComponents', () => { - const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(dsl.components); - - expect(nodes.length).toEqual(4); - expect(edges.length).toEqual(4); - - expect(edges).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - source: 'begin', - target: 'Answer:China', - }), - expect.objectContaining({ - source: 'Answer:China', - target: 'Retrieval:China', - }), - expect.objectContaining({ - source: 'Retrieval:China', - target: 'Generate:China', - }), - expect.objectContaining({ - source: 'Generate:China', - target: 'Answer:China', - }), - ]), - ); -}); - -test('build nodes and edges from headhunter_zh dsl', () => { - const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( - headhunter_zh.components, - ); - console.info('node length', nodes.length); - console.info('edge length', edges.length); - try { - fs.writeFileSync( - path.join(__dirname, 'headhunter_zh.json'), - JSON.stringify({ edges, nodes }, null, 4), - ); - console.log('JSON data is saved.'); - } catch (error) { - console.warn(error); - } - expect(nodes.length).toEqual(12); -}); - -test('build nodes and edges from customer_service dsl', () => { - const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( - customer_service.components, - ); - console.info('node length', nodes.length); - console.info('edge length', edges.length); - try { - fs.writeFileSync( - path.join(__dirname, 'customer_service.json'), - JSON.stringify({ edges, nodes }, null, 4), - ); - console.log('JSON data is saved.'); - } catch (error) { - console.warn(error); - } - expect(nodes.length).toEqual(12); -}); - -test('build nodes and edges from interpreter dsl', () => { - const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( - interpreter.components, - ); - console.info('node length', nodes.length); - console.info('edge length', edges.length); - try { - fs.writeFileSync( - path.join(__dirname, 'interpreter.json'), - JSON.stringify({ edges, nodes }, null, 4), - ); - console.log('JSON data is saved.'); - } catch (error) { - console.warn(error); - } - expect(nodes.length).toEqual(12); -}); - -test('build nodes and edges from chat bot dsl', () => { - const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( - retrievalRelevantRewriteAndGenerate.components, - ); - try { - fs.writeFileSync( - path.join(__dirname, 'retrieval_relevant_rewrite_and_generate.json'), - JSON.stringify({ edges, nodes }, null, 4), - ); - console.log('JSON data is saved.'); - } catch (error) { - console.warn(error); - } - expect(nodes.length).toEqual(12); -}); diff --git a/web/src/pages/agent/utils.ts b/web/src/pages/agent/utils.ts index 184cfdc83..c7d60dfd2 100644 --- a/web/src/pages/agent/utils.ts +++ b/web/src/pages/agent/utils.ts @@ -6,91 +6,28 @@ import { } from '@/interfaces/database/agent'; import { DSLComponents, RAGFlowNodeType } from '@/interfaces/database/flow'; import { removeUselessFieldsFromValues } from '@/utils/form'; -import { Edge, Node, Position, XYPosition } from '@xyflow/react'; +import { Edge, Node, XYPosition } from '@xyflow/react'; import { FormInstance, FormListFieldData } from 'antd'; import { humanId } from 'human-id'; import { curry, get, intersectionWith, isEqual, omit, sample } from 'lodash'; import pipe from 'lodash/fp/pipe'; import isObject from 'lodash/isObject'; -import { v4 as uuidv4 } from 'uuid'; import { CategorizeAnchorPointPositions, NoDebugOperatorsList, NodeHandleId, - NodeMap, Operator, } from './constant'; import { BeginQuery, IPosition } from './interface'; -const buildEdges = ( - operatorIds: string[], - currentId: string, - allEdges: Edge[], - isUpstream = false, - componentName: string, - nodeParams: Record, -) => { - operatorIds.forEach((cur) => { - const source = isUpstream ? cur : currentId; - const target = isUpstream ? currentId : cur; - if (!allEdges.some((e) => e.source === source && e.target === target)) { - const edge: Edge = { - id: uuidv4(), - label: '', - // type: 'step', - source: source, - target: target, - // markerEnd: { - // type: MarkerType.ArrowClosed, - // color: 'rgb(157 149 225)', - // width: 20, - // height: 20, - // }, - }; - if (componentName === Operator.Categorize && !isUpstream) { - const categoryDescription = - nodeParams.category_description as ICategorizeItemResult; +function buildAgentExceptionGoto(edges: Edge[], nodeId: string) { + const exceptionEdges = edges.filter( + (x) => + x.source === nodeId && x.sourceHandle === NodeHandleId.AgentException, + ); - const name = Object.keys(categoryDescription).find( - (x) => categoryDescription[x].to === target, - ); - - if (name) { - edge.sourceHandle = name; - } - } - allEdges.push(edge); - } - }); -}; - -export const buildNodesAndEdgesFromDSLComponents = (data: DSLComponents) => { - const nodes: Node[] = []; - let edges: Edge[] = []; - - Object.entries(data).forEach(([key, value]) => { - const downstream = [...value.downstream]; - const upstream = [...value.upstream]; - const { component_name: componentName, params } = value.obj; - nodes.push({ - id: key, - type: NodeMap[value.obj.component_name as Operator] || 'ragNode', - position: { x: 0, y: 0 }, - data: { - label: componentName, - name: humanId(), - form: params, - }, - sourcePosition: Position.Left, - targetPosition: Position.Right, - }); - - buildEdges(upstream, key, edges, true, componentName, params); - buildEdges(downstream, key, edges, false, componentName, params); - }); - - return { nodes, edges }; -}; + return exceptionEdges.map((x) => x.target); +} const buildComponentDownstreamOrUpstream = ( edges: Edge[], @@ -103,7 +40,9 @@ const buildComponentDownstreamOrUpstream = ( const node = nodes.find((x) => x.id === nodeId); let isNotUpstreamTool = true; let isNotUpstreamAgent = true; + let isNotExceptionGoto = true; if (isBuildDownstream && node?.data.label === Operator.Agent) { + isNotExceptionGoto = y.sourceHandle !== NodeHandleId.AgentException; // Exclude the tool operator downstream of the agent operator isNotUpstreamTool = !y.target.startsWith(Operator.Tool); // Exclude the agent operator downstream of the agent operator @@ -115,7 +54,8 @@ const buildComponentDownstreamOrUpstream = ( return ( y[isBuildDownstream ? 'source' : 'target'] === nodeId && isNotUpstreamTool && - isNotUpstreamAgent + isNotUpstreamAgent && + isNotExceptionGoto ); }) .map((y) => y[isBuildDownstream ? 'target' : 'source']); @@ -234,7 +174,10 @@ export const buildDslComponentsByGraph = ( switch (operatorName) { case Operator.Agent: { const { params: formData } = buildAgentTools(edges, nodes, id); - params = formData; + params = { + ...formData, + exception_goto: buildAgentExceptionGoto(edges, id), + }; break; } case Operator.Categorize: @@ -559,7 +502,7 @@ export const buildCategorizeObjectFromList = (list: Array) => { if (cur?.name) { pre[cur.name] = { ...omit(cur, 'name', 'examples'), - examples: convertToStringArray(cur.examples), + examples: convertToStringArray(cur.examples) as string[], }; } return pre;