From 47005ebe10e30bdf30d756cb1671600d0dfa4b3b Mon Sep 17 00:00:00 2001 From: Jimmy Ben Klieve Date: Mon, 22 Dec 2025 09:35:34 +0800 Subject: [PATCH] feat: supports multiple retrieval tool under an agent (#12046) ### What problem does this PR solve? Add support for multiple Retrieval tools under an agent ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/components/large-model-form-field.tsx | 2 +- web/src/interfaces/database/agent.ts | 1 + web/src/pages/agent/canvas/node/tool-node.tsx | 52 ++++---- web/src/pages/agent/canvas/node/toolbar.tsx | 30 ++--- web/src/pages/agent/constant/index.tsx | 1 + web/src/pages/agent/form-sheet/next.tsx | 51 ++++--- .../pages/agent/form-sheet/title-input.tsx | 55 ++++++-- .../agent/form/agent-form/agent-tools.tsx | 43 ++++-- .../form/agent-form/tool-popover/index.tsx | 20 +-- .../agent-form/tool-popover/tool-command.tsx | 45 ++++--- .../tool-popover/use-update-tools.ts | 70 +++++++--- .../agent/form/agent-form/use-get-tools.ts | 5 + web/src/pages/agent/form/tool-form/index.tsx | 4 +- .../pages/agent/form/tool-form/use-values.ts | 21 +-- .../agent/form/tool-form/use-watch-change.ts | 49 ++++--- .../pages/agent/hooks/use-change-node-name.ts | 74 ++++++----- web/src/pages/agent/hooks/use-is-mcp.ts | 6 +- web/src/pages/agent/hooks/use-show-drawer.tsx | 7 +- web/src/pages/agent/store.ts | 124 +++++++++++++++++- web/src/pages/agent/utils.ts | 8 +- 20 files changed, 442 insertions(+), 226 deletions(-) diff --git a/web/src/components/large-model-form-field.tsx b/web/src/components/large-model-form-field.tsx index e805eb541..0e266258c 100644 --- a/web/src/components/large-model-form-field.tsx +++ b/web/src/components/large-model-form-field.tsx @@ -68,7 +68,7 @@ export function LargeModelFormField({ - + diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index ebdf65836..7caa7eee8 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -170,6 +170,7 @@ export interface IAgentForm { tools: Array<{ name: string; component_name: string; + id: string; params: Record; }>; mcp: Array<{ diff --git a/web/src/pages/agent/canvas/node/tool-node.tsx b/web/src/pages/agent/canvas/node/tool-node.tsx index f1a8ca1da..5fc53b1aa 100644 --- a/web/src/pages/agent/canvas/node/tool-node.tsx +++ b/web/src/pages/agent/canvas/node/tool-node.tsx @@ -2,7 +2,7 @@ import { NodeCollapsible } from '@/components/collapse'; import { IAgentForm, IToolNode } from '@/interfaces/database/agent'; import { Handle, NodeProps, Position } from '@xyflow/react'; import { get } from 'lodash'; -import { MouseEventHandler, memo, useCallback } from 'react'; +import { memo } from 'react'; import { NodeHandleId, Operator } from '../../constant'; import { ToolCard } from '../../form/agent-form/agent-tools'; import { useFindMcpById } from '../../hooks/use-find-mcp-by-id'; @@ -15,22 +15,11 @@ function InnerToolNode({ isConnectable = true, selected, }: NodeProps) { - const { edges, getNode } = useGraphStore((state) => state); + const { edges, getNode, setClickedToolId } = useGraphStore(); const upstreamAgentNodeId = edges.find((x) => x.target === id)?.source; const upstreamAgentNode = getNode(upstreamAgentNodeId); const { findMcpById } = useFindMcpById(); - const handleClick = useCallback( - (operator: string): MouseEventHandler => - (e) => { - if (operator === Operator.Code) { - e.preventDefault(); - e.stopPropagation(); - } - }, - [], - ); - const tools: IAgentForm['tools'] = get( upstreamAgentNode, 'data.form.tools', @@ -51,17 +40,24 @@ function InnerToolNode({ position={Position.Top} isConnectable={isConnectable} className="!bg-accent-primary !size-2" - > + /> + {(x) => { - if ('mcp_id' in x) { + if (Reflect.has(x, 'mcp_id')) { const mcp = x as unknown as IAgentForm['mcp'][number]; + return ( { + if (mcp.mcp_id === Operator.Code) { + e.preventDefault(); + e.stopPropagation(); + } + }} className="cursor-pointer" - data-tool={x.mcp_id} + data-tool={mcp.mcp_id} > {findMcpById(mcp.mcp_id)?.name} @@ -69,18 +65,28 @@ function InnerToolNode({ } const tool = x as unknown as IAgentForm['tools'][number]; + return ( { + if (tool.component_name === Operator.Code) { + e.preventDefault(); + e.stopPropagation(); + } + + setClickedToolId(tool.id || tool.component_name); + }} className="cursor-pointer" data-tool={tool.component_name} + data-tool-id={tool.id} >
- - {tool.component_name} + + + {tool.component_name === Operator.Retrieval + ? tool.name + : tool.component_name}
); diff --git a/web/src/pages/agent/canvas/node/toolbar.tsx b/web/src/pages/agent/canvas/node/toolbar.tsx index 775ba228d..f7c103e29 100644 --- a/web/src/pages/agent/canvas/node/toolbar.tsx +++ b/web/src/pages/agent/canvas/node/toolbar.tsx @@ -1,3 +1,4 @@ +import { Button, ButtonProps } from '@/components/ui/button'; import { TooltipContent, TooltipNode, @@ -6,31 +7,24 @@ import { import { cn } from '@/lib/utils'; import { Position } from '@xyflow/react'; import { Copy, Play, Trash2 } from 'lucide-react'; -import { - HTMLAttributes, - MouseEventHandler, - PropsWithChildren, - useCallback, -} from 'react'; +import { MouseEventHandler, PropsWithChildren, useCallback } from 'react'; import { Operator } from '../../constant'; import { useDuplicateNode } from '../../hooks'; import useGraphStore from '../../store'; -function IconWrapper({ - children, - className, - ...props -}: HTMLAttributes) { +function IconWrapper({ children, className, ...props }: ButtonProps) { return ( -
{children} -
+ ); } @@ -55,7 +49,7 @@ export function ToolBar({ (store) => store.deleteIterationNodeById, ); - const deleteNode: MouseEventHandler = useCallback( + const deleteNode: MouseEventHandler = useCallback( (e) => { e.stopPropagation(); if ([Operator.Iteration, Operator.Loop].includes(label as Operator)) { @@ -69,7 +63,7 @@ export function ToolBar({ const duplicateNode = useDuplicateNode(); - const handleDuplicate: MouseEventHandler = useCallback( + const handleDuplicate: MouseEventHandler = useCallback( (e) => { e.stopPropagation(); duplicateNode(id, label); @@ -82,7 +76,7 @@ export function ToolBar({ {children} -
+
{showRun && ( @@ -94,8 +88,8 @@ export function ToolBar({ )} diff --git a/web/src/pages/agent/constant/index.tsx b/web/src/pages/agent/constant/index.tsx index 9731fbf3d..c548b896a 100644 --- a/web/src/pages/agent/constant/index.tsx +++ b/web/src/pages/agent/constant/index.tsx @@ -778,6 +778,7 @@ export const NoDebugOperatorsList = [ Operator.Splitter, Operator.HierarchicalMerger, Operator.Extractor, + Operator.Tool, ]; export const NoCopyOperatorsList = [ diff --git a/web/src/pages/agent/form-sheet/next.tsx b/web/src/pages/agent/form-sheet/next.tsx index b9ebeac51..1b8b0a65e 100644 --- a/web/src/pages/agent/form-sheet/next.tsx +++ b/web/src/pages/agent/form-sheet/next.tsx @@ -1,3 +1,4 @@ +import { Button } from '@/components/ui/button'; import { Sheet, SheetContent, @@ -41,49 +42,69 @@ const FormSheet = ({ showSingleDebugDrawer, }: IModalProps & IProps) => { const operatorName: Operator = node?.data.label as Operator; - const clickedToolId = useGraphStore((state) => state.clickedToolId); + const { clickedToolId, getAgentToolById } = useGraphStore(); const currentFormMap = FormConfigMap[operatorName]; - const OperatorForm = currentFormMap?.component ?? EmptyContent; - const isMcp = useIsMcp(operatorName); - const { t } = useTranslate('flow'); + const { component_name: toolComponentName } = (getAgentToolById( + clickedToolId, + ) ?? {}) as { + component_name: Operator; + name: string; + id: string; + }; return ( -
+
- + {needsSingleStepDebugging(operatorName) && ( - + > + + )} - + +
- {isMcp || ( - + + {!isMcp && ( +

{t( - `${lowerFirst(operatorName === Operator.Tool ? clickedToolId : operatorName)}Description`, + `${lowerFirst(operatorName === Operator.Tool ? toolComponentName : operatorName)}Description`, )} - +

)}
+
{visible && ( diff --git a/web/src/pages/agent/form-sheet/title-input.tsx b/web/src/pages/agent/form-sheet/title-input.tsx index e876e97ba..0cc520c86 100644 --- a/web/src/pages/agent/form-sheet/title-input.tsx +++ b/web/src/pages/agent/form-sheet/title-input.tsx @@ -1,7 +1,8 @@ +import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; import { RAGFlowNodeType } from '@/interfaces/database/agent'; import { PenLine } from 'lucide-react'; -import { useCallback, useState } from 'react'; +import { useCallback, useLayoutEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { BeginId, Operator } from '../constant'; import { useHandleNodeNameChange } from '../hooks/use-change-node-name'; @@ -13,47 +14,75 @@ type TitleInputProps = { export function TitleInput({ node }: TitleInputProps) { const { t } = useTranslation(); + const inputRef = useRef(null); const { name, handleNameBlur, handleNameChange } = useHandleNodeNameChange({ id: node?.id, data: node?.data, }); const operatorName: Operator = node?.data.label as Operator; - const isMcp = useIsMcp(operatorName); - const [isEditingMode, setIsEditingMode] = useState(false); const switchIsEditingMode = useCallback(() => { setIsEditingMode((prev) => !prev); }, []); - const handleBlur = useCallback(() => { - handleNameBlur(); - setIsEditingMode(false); - }, [handleNameBlur]); + const handleBlur = useCallback( + (e: React.FocusEvent) => { + if (handleNameBlur()) { + setIsEditingMode(false); + } else { + // Re-focus the input if name doesn't change successfully + e.target.focus(); + e.target.select(); + } + }, + [handleNameBlur], + ); + + useLayoutEffect(() => { + if (isEditingMode && inputRef.current) { + inputRef.current.focus(); + inputRef.current.select(); + } + }, [isEditingMode]); if (isMcp) { return
MCP Config
; } return ( -
+ // Give a fixed height to prevent layout shift when switching between edit and view modes +
{node?.id === BeginId ? ( - {t(BeginId)} + // Begin node is not editable + {t(`flow.${BeginId}`)} ) : isEditingMode ? ( { + // Support committing the value changes by pressing Enter + if (e.key === 'Enter') { + handleBlur(e as unknown as React.FocusEvent); + } + }} onChange={handleNameChange} - > + /> ) : (
{name} - + > + +
)}
diff --git a/web/src/pages/agent/form/agent-form/agent-tools.tsx b/web/src/pages/agent/form/agent-form/agent-tools.tsx index 2ac473686..347abbfd2 100644 --- a/web/src/pages/agent/form/agent-form/agent-tools.tsx +++ b/web/src/pages/agent/form/agent-form/agent-tools.tsx @@ -1,4 +1,4 @@ -import { BlockButton } from '@/components/ui/button'; +import { BlockButton, Button } from '@/components/ui/button'; import { Tooltip, TooltipContent, @@ -26,7 +26,7 @@ import { filterDownstreamAgentNodeIds } from '../../utils/filter-downstream-node import { ToolPopover } from './tool-popover'; import { useDeleteAgentNodeMCP } from './tool-popover/use-update-mcp'; import { useDeleteAgentNodeTools } from './tool-popover/use-update-tools'; -import { useGetAgentMCPIds, useGetAgentToolNames } from './use-get-tools'; +import { useGetAgentMCPIds, useGetNodeTools } from './use-get-tools'; type ToolCardProps = React.HTMLAttributes & PropsWithChildren & { @@ -79,20 +79,33 @@ function ActionButton({ deleteRecord, record, edit }: ActionButtonProps) { deleteRecord(record); }, [deleteRecord, record]); + // Wrapping into buttons to solve the issue that clicking icon occasionally not jumping to corresponding form return (
- - + > + + + +
); } export function AgentTools() { - const { toolNames } = useGetAgentToolNames(); + const tools = useGetNodeTools(); const { deleteNodeTool } = useDeleteAgentNodeTools(); const { mcpIds } = useGetAgentMCPIds(); const { findMcpById } = useFindMcpById(); @@ -105,6 +118,7 @@ export function AgentTools() { const handleEdit: MouseEventHandler = useCallback( (e) => { const toolNodeId = findAgentToolNodeById(clickedNodeId); + if (toolNodeId) { selectNodeIds([toolNodeId]); showFormDrawer(e, toolNodeId); @@ -117,19 +131,20 @@ export function AgentTools() {
{t('flow.tools')}
    - {toolNames.map((x) => ( - + {tools.map(({ id, component_name, name }) => ( +
    - - {x} + + {component_name === Operator.Retrieval ? name : component_name}
    + />
    ))} + {mcpIds.map((id) => ( {findMcpById(id)?.name} diff --git a/web/src/pages/agent/form/agent-form/tool-popover/index.tsx b/web/src/pages/agent/form/agent-form/tool-popover/index.tsx index ce294f091..b6c032c1e 100644 --- a/web/src/pages/agent/form/agent-form/tool-popover/index.tsx +++ b/web/src/pages/agent/form/agent-form/tool-popover/index.tsx @@ -9,21 +9,19 @@ import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context'; import useGraphStore from '@/pages/agent/store'; import { Position } from '@xyflow/react'; import { t } from 'i18next'; -import { PropsWithChildren, useCallback, useContext, useEffect } from 'react'; +import { useContext, useEffect } from 'react'; import { useGetAgentMCPIds, useGetAgentToolNames } from '../use-get-tools'; import { MCPCommand, ToolCommand } from './tool-command'; import { useUpdateAgentNodeMCP } from './use-update-mcp'; -import { useUpdateAgentNodeTools } from './use-update-tools'; enum ToolType { Common = 'common', MCP = 'mcp', } -export function ToolPopover({ children }: PropsWithChildren) { +export function ToolPopover({ children }: React.PropsWithChildren) { const { addCanvasNode } = useContext(AgentInstanceContext); const node = useContext(AgentFormContext); - const { updateNodeTools } = useUpdateAgentNodeTools(); const { toolNames } = useGetAgentToolNames(); const deleteAgentToolNodeById = useGraphStore( (state) => state.deleteAgentToolNodeById, @@ -31,15 +29,6 @@ export function ToolPopover({ children }: PropsWithChildren) { const { mcpIds } = useGetAgentMCPIds(); const { updateNodeMCP } = useUpdateAgentNodeMCP(); - const handleChange = useCallback( - (value: string[]) => { - if (Array.isArray(value) && node?.id) { - updateNodeTools(value); - } - }, - [node?.id, updateNodeTools], - ); - useEffect(() => { const total = toolNames.length + mcpIds.length; if (node?.id) { @@ -72,10 +61,7 @@ export function ToolPopover({ children }: PropsWithChildren) { MCP - + diff --git a/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx b/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx index 7ade01008..b5edf6d89 100644 --- a/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx +++ b/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx @@ -12,8 +12,10 @@ import { Operator } from '@/pages/agent/constant'; import OperatorIcon from '@/pages/agent/operator-icon'; import { t } from 'i18next'; import { lowerFirst } from 'lodash'; +import { LucidePlus } from 'lucide-react'; import { PropsWithChildren, useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetNodeTools, useUpdateAgentNodeTools } from './use-update-tools'; const Menus = [ { @@ -66,7 +68,13 @@ function ToolCommandItem({ }: ToolCommandItemProps & PropsWithChildren) { return ( toggleOption(id)}> - + {id === Operator.Retrieval ? ( + + + + ) : ( + + )} {children} ); @@ -98,12 +106,12 @@ function useHandleSelectChange({ onChange, value }: ToolCommandProps) { }; } +// eslint-disable-next-line export function ToolCommand({ value, onChange }: ToolCommandProps) { const { t } = useTranslation(); - const { toggleOption, currentValue } = useHandleSelectChange({ - onChange, - value, - }); + + const currentValue = useGetNodeTools(); + const { updateNodeTools } = useUpdateAgentNodeTools(); return ( @@ -112,22 +120,17 @@ export function ToolCommand({ value, onChange }: ToolCommandProps) { No results found. {Menus.map((x) => ( - {x.list.map((y) => { - const isSelected = currentValue.includes(y); - return ( - - <> - - {t(`flow.${lowerFirst(y)}`)} - - - ); - })} + {x.list.map((y) => ( + x.component_name === y)} + > + + {t(`flow.${lowerFirst(y)}`)} + + ))} ))} diff --git a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts index db579561a..bc91f2275 100644 --- a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts +++ b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts @@ -16,32 +16,59 @@ export function useGetNodeTools() { } export function useUpdateAgentNodeTools() { - const { updateNodeForm } = useGraphStore((state) => state); - const node = useContext(AgentFormContext); + const { generateAgentToolName, generateAgentToolId, updateNodeForm } = + useGraphStore((state) => state); + const node = useContext(AgentFormContext)!; const tools = useGetNodeTools(); const { initializeAgentToolValues } = useAgentToolInitialValues(); const updateNodeTools = useCallback( - (value: string[]) => { - if (node?.id) { - const nextValue = value.reduce((pre, cur) => { - const tool = tools.find((x) => x.component_name === cur); - pre.push( - tool - ? tool - : { - component_name: cur, - name: cur, - params: initializeAgentToolValues(cur as Operator), - }, - ); - return pre; - }, []); + (value: string) => { + if (!node?.id) return; - updateNodeForm(node?.id, nextValue, ['tools']); + // Append + if (value === Operator.Retrieval) { + updateNodeForm( + node.id, + [ + ...tools, + { + component_name: value, + name: generateAgentToolName(node.id, value), + params: initializeAgentToolValues(value as Operator), + id: generateAgentToolId(value), + }, + ], + ['tools'], + ); + } + // Toggle + else { + updateNodeForm( + node.id, + tools.some((x) => x.component_name === value) + ? tools.filter((x) => x.component_name !== value) + : [ + ...tools, + { + component_name: value, + name: value, + params: initializeAgentToolValues(value as Operator), + id: generateAgentToolId(value), + }, + ], + ['tools'], + ); } }, - [initializeAgentToolValues, node?.id, tools, updateNodeForm], + [ + generateAgentToolName, + generateAgentToolId, + initializeAgentToolValues, + node?.id, + tools, + updateNodeForm, + ], ); return { updateNodeTools }; @@ -53,8 +80,9 @@ export function useDeleteAgentNodeTools() { const node = useContext(AgentFormContext); const deleteNodeTool = useCallback( - (value: string) => () => { - const nextTools = tools.filter((x) => x.component_name !== value); + (toolId: string) => () => { + const nextTools = tools.filter((x) => x.id !== toolId); + if (node?.id) { updateNodeForm(node?.id, nextTools, ['tools']); } diff --git a/web/src/pages/agent/form/agent-form/use-get-tools.ts b/web/src/pages/agent/form/agent-form/use-get-tools.ts index 32bf3f0ef..9cb8decdb 100644 --- a/web/src/pages/agent/form/agent-form/use-get-tools.ts +++ b/web/src/pages/agent/form/agent-form/use-get-tools.ts @@ -3,6 +3,11 @@ import { get } from 'lodash'; import { useContext, useMemo } from 'react'; import { AgentFormContext } from '../../context'; +export function useGetNodeTools() { + const node = useContext(AgentFormContext); + return get(node, 'data.form.tools', []) as IAgentForm['tools']; +} + export function useGetAgentToolNames() { const node = useContext(AgentFormContext); diff --git a/web/src/pages/agent/form/tool-form/index.tsx b/web/src/pages/agent/form/tool-form/index.tsx index 9c03870b4..639c55c86 100644 --- a/web/src/pages/agent/form/tool-form/index.tsx +++ b/web/src/pages/agent/form/tool-form/index.tsx @@ -7,9 +7,11 @@ const EmptyContent = () =>
    ; function ToolForm() { const clickedToolId = useGraphStore((state) => state.clickedToolId); + const { getAgentToolById } = useGraphStore(); + const tool = getAgentToolById(clickedToolId); const ToolForm = - ToolFormConfigMap[clickedToolId as keyof typeof ToolFormConfigMap] ?? + ToolFormConfigMap[tool?.component_name as keyof typeof ToolFormConfigMap] ?? MCPForm ?? EmptyContent; diff --git a/web/src/pages/agent/form/tool-form/use-values.ts b/web/src/pages/agent/form/tool-form/use-values.ts index 59a2e090f..6000b6c07 100644 --- a/web/src/pages/agent/form/tool-form/use-values.ts +++ b/web/src/pages/agent/form/tool-form/use-values.ts @@ -3,7 +3,6 @@ import { useMemo } from 'react'; import { Operator } from '../../constant'; import { useAgentToolInitialValues } from '../../hooks/use-agent-tool-initial-values'; import useGraphStore from '../../store'; -import { getAgentNodeTools } from '../../utils'; export enum SearchDepth { Basic = 'basic', @@ -16,22 +15,23 @@ export enum Topic { } export function useValues() { - const { clickedToolId, clickedNodeId, findUpstreamNodeById } = useGraphStore( - (state) => state, - ); + const { + clickedToolId, + clickedNodeId, + findUpstreamNodeById, + getAgentToolById, + } = useGraphStore(); + const { initializeAgentToolValues } = useAgentToolInitialValues(); const values = useMemo(() => { const agentNode = findUpstreamNodeById(clickedNodeId); - const tools = getAgentNodeTools(agentNode); - - const formData = tools.find( - (x) => x.component_name === clickedToolId, - )?.params; + const tool = getAgentToolById(clickedToolId, agentNode!); + const formData = tool?.params; if (isEmpty(formData)) { const defaultValues = initializeAgentToolValues( - clickedNodeId as Operator, + (tool?.component_name || clickedNodeId) as Operator, ); return defaultValues; @@ -44,6 +44,7 @@ export function useValues() { clickedNodeId, clickedToolId, findUpstreamNodeById, + getAgentToolById, initializeAgentToolValues, ]); diff --git a/web/src/pages/agent/form/tool-form/use-watch-change.ts b/web/src/pages/agent/form/tool-form/use-watch-change.ts index 81a70a235..3807592c5 100644 --- a/web/src/pages/agent/form/tool-form/use-watch-change.ts +++ b/web/src/pages/agent/form/tool-form/use-watch-change.ts @@ -1,39 +1,38 @@ import { useEffect } from 'react'; import { UseFormReturn, useWatch } from 'react-hook-form'; import useGraphStore from '../../store'; -import { getAgentNodeTools } from '../../utils'; export function useWatchFormChange(form?: UseFormReturn) { let values = useWatch({ control: form?.control }); - const { clickedToolId, clickedNodeId, findUpstreamNodeById, updateNodeForm } = - useGraphStore((state) => state); + + const { + clickedToolId, + clickedNodeId, + findUpstreamNodeById, + getAgentToolById, + updateAgentToolById, + updateNodeForm, + } = useGraphStore(); useEffect(() => { const agentNode = findUpstreamNodeById(clickedNodeId); // Manually triggered form updates are synchronized to the canvas if (agentNode && form?.formState.isDirty) { - const agentNodeId = agentNode?.id; - const tools = getAgentNodeTools(agentNode); - - values = form?.getValues(); - const nextTools = tools.map((x) => { - if (x.component_name === clickedToolId) { - return { - ...x, - params: { - ...values, - }, - }; - } - return x; + updateAgentToolById(agentNode, clickedToolId, { + params: { + ...(values ?? {}), + }, }); - - const nextValues = { - ...(agentNode?.data?.form ?? {}), - tools: nextTools, - }; - - updateNodeForm(agentNodeId, nextValues); } - }, [form?.formState.isDirty, updateNodeForm, values]); + }, [ + clickedNodeId, + clickedToolId, + findUpstreamNodeById, + form, + form?.formState.isDirty, + getAgentToolById, + updateAgentToolById, + updateNodeForm, + values, + ]); } diff --git a/web/src/pages/agent/hooks/use-change-node-name.ts b/web/src/pages/agent/hooks/use-change-node-name.ts index 61a5653d7..9d6112c2d 100644 --- a/web/src/pages/agent/hooks/use-change-node-name.ts +++ b/web/src/pages/agent/hooks/use-change-node-name.ts @@ -6,14 +6,13 @@ import { SetStateAction, useCallback, useEffect, - useMemo, useState, } from 'react'; import { Operator } from '../constant'; import useGraphStore from '../store'; import { getAgentNodeTools } from '../utils'; -export function useHandleTooNodeNameChange({ +export function useHandleToolNodeNameChange({ id, name, setName, @@ -22,48 +21,44 @@ export function useHandleTooNodeNameChange({ name?: string; setName: Dispatch>; }) { - const { clickedToolId, findUpstreamNodeById, updateNodeForm } = useGraphStore( - (state) => state, - ); - const agentNode = findUpstreamNodeById(id); + const { + clickedToolId, + findUpstreamNodeById, + getAgentToolById, + updateAgentToolById, + } = useGraphStore((state) => state); + const agentNode = findUpstreamNodeById(id)!; const tools = getAgentNodeTools(agentNode); - - const previousName = useMemo(() => { - const tool = tools.find((x) => x.component_name === clickedToolId); - return tool?.name || tool?.component_name; - }, [clickedToolId, tools]); + const previousName = getAgentToolById(clickedToolId, agentNode)?.name; const handleToolNameBlur = useCallback(() => { const trimmedName = trim(name); const existsSameName = tools.some((x) => x.name === trimmedName); - if (trimmedName === '' || existsSameName) { - if (existsSameName && previousName !== name) { - message.error('The name cannot be repeated'); - } + + // Not changed + if (trimmedName === '') { setName(previousName || ''); - return; + return true; + } + + if (existsSameName && previousName !== name) { + message.error('The name cannot be repeated'); + return false; } if (agentNode?.id) { - const nextTools = tools.map((x) => { - if (x.component_name === clickedToolId) { - return { - ...x, - name, - }; - } - return x; - }); - updateNodeForm(agentNode?.id, nextTools, ['tools']); + updateAgentToolById(agentNode, clickedToolId, { name }); } + + return true; }, [ - agentNode?.id, + agentNode, clickedToolId, name, previousName, setName, tools, - updateNodeForm, + updateAgentToolById, ]); return { handleToolNameBlur, previousToolName: previousName }; @@ -83,28 +78,35 @@ export const useHandleNodeNameChange = ({ const previousName = data?.name; const isToolNode = getOperatorTypeFromId(id) === Operator.Tool; - const { handleToolNameBlur, previousToolName } = useHandleTooNodeNameChange({ + const { handleToolNameBlur, previousToolName } = useHandleToolNodeNameChange({ id, name, setName, }); const handleNameBlur = useCallback(() => { + const trimmedName = trim(name); const existsSameName = nodes.some((x) => x.data.name === name); - if (trim(name) === '' || existsSameName) { - if (existsSameName && previousName !== name) { - message.error('The name cannot be repeated'); - } - setName(previousName); - return; + + // Not changed + if (!trimmedName) { + setName(previousName || ''); + return true; + } + + if (existsSameName && previousName !== name) { + message.error('The name cannot be repeated'); + return false; } if (id) { updateNodeName(id, name); } + + return true; }, [name, id, updateNodeName, previousName, nodes]); - const handleNameChange = useCallback((e: ChangeEvent) => { + const handleNameChange = useCallback((e: ChangeEvent) => { setName(e.target.value); }, []); diff --git a/web/src/pages/agent/hooks/use-is-mcp.ts b/web/src/pages/agent/hooks/use-is-mcp.ts index 5cec74487..bd76b62f4 100644 --- a/web/src/pages/agent/hooks/use-is-mcp.ts +++ b/web/src/pages/agent/hooks/use-is-mcp.ts @@ -2,10 +2,12 @@ import { Operator } from '../constant'; import useGraphStore from '../store'; export function useIsMcp(operatorName: Operator) { - const clickedToolId = useGraphStore((state) => state.clickedToolId); + const { clickedToolId, getAgentToolById } = useGraphStore(); + + const { component_name: toolName } = getAgentToolById(clickedToolId) ?? {}; return ( operatorName === Operator.Tool && - Object.values(Operator).every((x) => x !== clickedToolId) + Object.values(Operator).every((x) => x !== toolName) ); } diff --git a/web/src/pages/agent/hooks/use-show-drawer.tsx b/web/src/pages/agent/hooks/use-show-drawer.tsx index 8804c6a21..3a15b29b8 100644 --- a/web/src/pages/agent/hooks/use-show-drawer.tsx +++ b/web/src/pages/agent/hooks/use-show-drawer.tsx @@ -24,7 +24,9 @@ export const useShowFormDrawer = () => { const handleShow = useCallback( (e: React.MouseEvent, nodeId: string) => { - const tool = get(e.target, 'dataset.tool'); + const toolId = (e.target as HTMLElement).dataset.toolId; + const tool = (e.target as HTMLElement).dataset.tool; + // TODO: Operator type judgment should be used const operatorType = getOperatorTypeFromId(nodeId); if ( @@ -36,7 +38,8 @@ export const useShowFormDrawer = () => { return; } setClickedNodeId(nodeId); - setClickedToolId(tool); + // Guess this could gracefully handle the case where the tool id is not provided? + setClickedToolId(toolId || tool); showFormDrawer(); }, [getOperatorTypeFromId, setClickedNodeId, setClickedToolId, showFormDrawer], diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index 42ce97444..98b4b9da2 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -1,4 +1,5 @@ -import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import type { IAgentForm } from '@/interfaces/database/agent'; +import { IAgentNode, RAGFlowNodeType } from '@/interfaces/database/flow'; import type {} from '@redux-devtools/extension'; import { Connection, @@ -14,10 +15,15 @@ import { applyEdgeChanges, applyNodeChanges, } from '@xyflow/react'; -import { cloneDeep, omit } from 'lodash'; -import differenceWith from 'lodash/differenceWith'; -import intersectionWith from 'lodash/intersectionWith'; -import lodashSet from 'lodash/set'; +import humanId from 'human-id'; +import { + cloneDeep, + differenceWith, + intersectionWith, + get as lodashGet, + set as lodashSet, + omit, +} from 'lodash'; import { create } from 'zustand'; import { devtools } from 'zustand/middleware'; import { immer } from 'zustand/middleware/immer'; @@ -26,12 +32,26 @@ import { duplicateNodeForm, generateDuplicateNode, generateNodeNamesWithIncreasingIndex, + getAgentNodeTools, getOperatorIndex, isEdgeEqual, mapEdgeMouseEvent, } from './utils'; import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node'; +type IAgentTool = IAgentForm['tools'][number]; + +interface GetAgentToolByIdFunc { + (id: string): IAgentTool | undefined; + (id: string, agentNode: RAGFlowNodeType): IAgentTool | undefined; + (id: string, agentNodeId: string): IAgentTool | undefined; +} + +interface UpdateAgentToolByIdFunc { + (agentNode: RAGFlowNodeType, id: string, value?: Partial): void; + (agentNodeId: string, id: string, value?: Partial): void; +} + export type RFState = { nodes: RAGFlowNodeType[]; edges: Edge[]; @@ -81,6 +101,11 @@ export type RFState = { getParentIdById: (id?: string | null) => string | undefined; updateNodeName: (id: string, name: string) => void; generateNodeName: (name: string) => string; + generateAgentToolName: (id: string, name: string) => string; + generateAgentToolId: (prefix: string) => string; + getAllAgentTools: () => IAgentTool[]; + getAgentToolById: GetAgentToolByIdFunc; + updateAgentToolById: UpdateAgentToolByIdFunc; setClickedNodeId: (id?: string) => void; setClickedToolId: (id?: string) => void; findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined; @@ -501,6 +526,95 @@ const useGraphStore = create()( return generateNodeNamesWithIncreasingIndex(name, nodes); }, + generateAgentToolName: (id: string, name: string) => { + const node = get().nodes.find( + (x) => x.id === id, + ) as IAgentNode; + + if (!node) { + return ''; + } + + const tools = (node.data.form!.tools as any[]).filter( + (x) => x.component_name === name, + ); + const lastIndex = tools.length + ? tools + .map((x) => { + const idx = x.name.match(/(\d+)$/)?.[1]; + return idx && isNaN(idx) ? -1 : Number(idx); + }) + .sort((a, b) => a - b) + .at(-1) ?? -1 + : -1; + + return `${name}_${lastIndex + 1}`; + }, + generateAgentToolId: (prefix: string) => { + const allAgentToolIds = get() + .getAllAgentTools() + .map((t) => t.id || t.component_name); + + let id: string; + + // Loop for avoiding id collisions + do { + id = `${prefix}:${humanId()}`; + } while (allAgentToolIds.includes(id)); + + return id; + }, + getAllAgentTools: () => { + return get() + .nodes.filter((n) => n?.data?.label === Operator.Agent) + .flatMap((n) => n?.data?.form?.tools); + }, + getAgentToolById: ( + id: string, + nodeOrNodeId?: RAGFlowNodeType | string, + ) => { + // eslint-disable-next-line eqeqeq + const tools = + nodeOrNodeId != null + ? getAgentNodeTools( + typeof nodeOrNodeId === 'string' + ? get().getNode(nodeOrNodeId) + : nodeOrNodeId, + ) + : get().getAllAgentTools(); + + // For backward compatibility + return tools.find((t) => (t.id || t.component_name) === id); + }, + updateAgentToolById: ( + nodeOrNodeId: RAGFlowNodeType | string, + id: string, + value?: Partial, + ) => { + const { getNode, updateNodeForm } = get(); + + const agentNode = + typeof nodeOrNodeId === 'string' + ? getNode(nodeOrNodeId) + : nodeOrNodeId; + + if (!agentNode) { + return; + } + + const toolIndex = getAgentNodeTools(agentNode).findIndex( + (t) => (t.id || t.component_name) === id, + ); + + updateNodeForm( + agentNode.id, + { + ...lodashGet(agentNode.data.form, ['tools', toolIndex], {}), + ...(value ?? {}), + }, + ['tools', toolIndex], + ); + }, setClickedToolId: (id?: string) => { set({ clickedToolId: id }); }, diff --git a/web/src/pages/agent/utils.ts b/web/src/pages/agent/utils.ts index b6397754d..cea553a9d 100644 --- a/web/src/pages/agent/utils.ts +++ b/web/src/pages/agent/utils.ts @@ -120,13 +120,17 @@ function buildAgentTools(edges: Edge[], nodes: Node[], nodeId: string) { return { component_name: Operator.Agent, id, - name: name as string, // Cast name to string and provide fallback + name, params: { ...formData }, }; }), ); } - return { params, name: node?.data.name, id: node?.id }; + return { params, name: node?.data.name, id: node?.id } as { + params: IAgentForm; + name: string; + id: string; + }; } function filterTargetsBySourceHandleId(edges: Edge[], handleId: string) {