feat: supports multiple retrieval tool under an agent (#12046)

### What problem does this PR solve?

Add support for multiple Retrieval tools under an agent

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Jimmy Ben Klieve
2025-12-22 09:35:34 +08:00
committed by GitHub
parent 3ee47e4af7
commit 47005ebe10
20 changed files with 442 additions and 226 deletions

View File

@ -68,7 +68,7 @@ export function LargeModelFormField({
<FormItem> <FormItem>
<FormControl> <FormControl>
<DropdownMenu> <DropdownMenu>
<DropdownMenuTrigger> <DropdownMenuTrigger asChild>
<Button variant={'ghost'}> <Button variant={'ghost'}>
<Funnel className="text-text-disabled" /> <Funnel className="text-text-disabled" />
</Button> </Button>

View File

@ -170,6 +170,7 @@ export interface IAgentForm {
tools: Array<{ tools: Array<{
name: string; name: string;
component_name: string; component_name: string;
id: string;
params: Record<string, any>; params: Record<string, any>;
}>; }>;
mcp: Array<{ mcp: Array<{

View File

@ -2,7 +2,7 @@ import { NodeCollapsible } from '@/components/collapse';
import { IAgentForm, IToolNode } from '@/interfaces/database/agent'; import { IAgentForm, IToolNode } from '@/interfaces/database/agent';
import { Handle, NodeProps, Position } from '@xyflow/react'; import { Handle, NodeProps, Position } from '@xyflow/react';
import { get } from 'lodash'; import { get } from 'lodash';
import { MouseEventHandler, memo, useCallback } from 'react'; import { memo } from 'react';
import { NodeHandleId, Operator } from '../../constant'; import { NodeHandleId, Operator } from '../../constant';
import { ToolCard } from '../../form/agent-form/agent-tools'; import { ToolCard } from '../../form/agent-form/agent-tools';
import { useFindMcpById } from '../../hooks/use-find-mcp-by-id'; import { useFindMcpById } from '../../hooks/use-find-mcp-by-id';
@ -15,22 +15,11 @@ function InnerToolNode({
isConnectable = true, isConnectable = true,
selected, selected,
}: NodeProps<IToolNode>) { }: NodeProps<IToolNode>) {
const { edges, getNode } = useGraphStore((state) => state); const { edges, getNode, setClickedToolId } = useGraphStore();
const upstreamAgentNodeId = edges.find((x) => x.target === id)?.source; const upstreamAgentNodeId = edges.find((x) => x.target === id)?.source;
const upstreamAgentNode = getNode(upstreamAgentNodeId); const upstreamAgentNode = getNode(upstreamAgentNodeId);
const { findMcpById } = useFindMcpById(); const { findMcpById } = useFindMcpById();
const handleClick = useCallback(
(operator: string): MouseEventHandler<HTMLLIElement> =>
(e) => {
if (operator === Operator.Code) {
e.preventDefault();
e.stopPropagation();
}
},
[],
);
const tools: IAgentForm['tools'] = get( const tools: IAgentForm['tools'] = get(
upstreamAgentNode, upstreamAgentNode,
'data.form.tools', 'data.form.tools',
@ -51,17 +40,24 @@ function InnerToolNode({
position={Position.Top} position={Position.Top}
isConnectable={isConnectable} isConnectable={isConnectable}
className="!bg-accent-primary !size-2" className="!bg-accent-primary !size-2"
></Handle> />
<NodeCollapsible items={[tools, mcpList]}> <NodeCollapsible items={[tools, mcpList]}>
{(x) => { {(x) => {
if ('mcp_id' in x) { if (Reflect.has(x, 'mcp_id')) {
const mcp = x as unknown as IAgentForm['mcp'][number]; const mcp = x as unknown as IAgentForm['mcp'][number];
return ( return (
<ToolCard <ToolCard
key={mcp.mcp_id} key={mcp.mcp_id}
onClick={handleClick(mcp.mcp_id)} onClick={(e) => {
if (mcp.mcp_id === Operator.Code) {
e.preventDefault();
e.stopPropagation();
}
}}
className="cursor-pointer" className="cursor-pointer"
data-tool={x.mcp_id} data-tool={mcp.mcp_id}
> >
{findMcpById(mcp.mcp_id)?.name} {findMcpById(mcp.mcp_id)?.name}
</ToolCard> </ToolCard>
@ -69,18 +65,28 @@ function InnerToolNode({
} }
const tool = x as unknown as IAgentForm['tools'][number]; const tool = x as unknown as IAgentForm['tools'][number];
return ( return (
<ToolCard <ToolCard
key={tool.component_name} key={tool.id}
onClick={handleClick(tool.component_name)} onClick={(e) => {
if (tool.component_name === Operator.Code) {
e.preventDefault();
e.stopPropagation();
}
setClickedToolId(tool.id || tool.component_name);
}}
className="cursor-pointer" className="cursor-pointer"
data-tool={tool.component_name} data-tool={tool.component_name}
data-tool-id={tool.id}
> >
<div className="flex gap-1 items-center pointer-events-none"> <div className="flex gap-1 items-center pointer-events-none">
<OperatorIcon <OperatorIcon name={tool.component_name as Operator} />
name={tool.component_name as Operator}
></OperatorIcon> {tool.component_name === Operator.Retrieval
{tool.component_name} ? tool.name
: tool.component_name}
</div> </div>
</ToolCard> </ToolCard>
); );

View File

@ -1,3 +1,4 @@
import { Button, ButtonProps } from '@/components/ui/button';
import { import {
TooltipContent, TooltipContent,
TooltipNode, TooltipNode,
@ -6,31 +7,24 @@ import {
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { Position } from '@xyflow/react'; import { Position } from '@xyflow/react';
import { Copy, Play, Trash2 } from 'lucide-react'; import { Copy, Play, Trash2 } from 'lucide-react';
import { import { MouseEventHandler, PropsWithChildren, useCallback } from 'react';
HTMLAttributes,
MouseEventHandler,
PropsWithChildren,
useCallback,
} from 'react';
import { Operator } from '../../constant'; import { Operator } from '../../constant';
import { useDuplicateNode } from '../../hooks'; import { useDuplicateNode } from '../../hooks';
import useGraphStore from '../../store'; import useGraphStore from '../../store';
function IconWrapper({ function IconWrapper({ children, className, ...props }: ButtonProps) {
children,
className,
...props
}: HTMLAttributes<HTMLDivElement>) {
return ( return (
<div <Button
variant="secondary"
size="icon"
className={cn( className={cn(
'p-1.5 bg-bg-component border border-border-button rounded-sm cursor-pointer hover:text-text-primary', 'size-7 p-0 bg-bg-component text-current hover:text-text-primary focus-visible:text-text-primary',
className, className,
)} )}
{...props} {...props}
> >
{children} {children}
</div> </Button>
); );
} }
@ -55,7 +49,7 @@ export function ToolBar({
(store) => store.deleteIterationNodeById, (store) => store.deleteIterationNodeById,
); );
const deleteNode: MouseEventHandler<HTMLDivElement> = useCallback( const deleteNode: MouseEventHandler<HTMLButtonElement> = useCallback(
(e) => { (e) => {
e.stopPropagation(); e.stopPropagation();
if ([Operator.Iteration, Operator.Loop].includes(label as Operator)) { if ([Operator.Iteration, Operator.Loop].includes(label as Operator)) {
@ -69,7 +63,7 @@ export function ToolBar({
const duplicateNode = useDuplicateNode(); const duplicateNode = useDuplicateNode();
const handleDuplicate: MouseEventHandler<HTMLDivElement> = useCallback( const handleDuplicate: MouseEventHandler<HTMLButtonElement> = useCallback(
(e) => { (e) => {
e.stopPropagation(); e.stopPropagation();
duplicateNode(id, label); duplicateNode(id, label);
@ -82,7 +76,7 @@ export function ToolBar({
<TooltipTrigger className="h-full">{children}</TooltipTrigger> <TooltipTrigger className="h-full">{children}</TooltipTrigger>
<TooltipContent position={Position.Top}> <TooltipContent position={Position.Top}>
<section className="flex gap-2 items-center text-text-secondary"> <section className="flex gap-2 items-center text-text-secondary pb-2">
{showRun && ( {showRun && (
<IconWrapper> <IconWrapper>
<Play className="size-3.5" data-play /> <Play className="size-3.5" data-play />
@ -94,8 +88,8 @@ export function ToolBar({
</IconWrapper> </IconWrapper>
)} )}
<IconWrapper <IconWrapper
onClick={deleteNode}
className="hover:text-state-error hover:border-state-error" className="hover:text-state-error hover:border-state-error"
onClick={deleteNode}
> >
<Trash2 className="size-3.5" /> <Trash2 className="size-3.5" />
</IconWrapper> </IconWrapper>

View File

@ -778,6 +778,7 @@ export const NoDebugOperatorsList = [
Operator.Splitter, Operator.Splitter,
Operator.HierarchicalMerger, Operator.HierarchicalMerger,
Operator.Extractor, Operator.Extractor,
Operator.Tool,
]; ];
export const NoCopyOperatorsList = [ export const NoCopyOperatorsList = [

View File

@ -1,3 +1,4 @@
import { Button } from '@/components/ui/button';
import { import {
Sheet, Sheet,
SheetContent, SheetContent,
@ -41,49 +42,69 @@ const FormSheet = ({
showSingleDebugDrawer, showSingleDebugDrawer,
}: IModalProps<any> & IProps) => { }: IModalProps<any> & IProps) => {
const operatorName: Operator = node?.data.label as Operator; const operatorName: Operator = node?.data.label as Operator;
const clickedToolId = useGraphStore((state) => state.clickedToolId); const { clickedToolId, getAgentToolById } = useGraphStore();
const currentFormMap = FormConfigMap[operatorName]; const currentFormMap = FormConfigMap[operatorName];
const OperatorForm = currentFormMap?.component ?? EmptyContent; const OperatorForm = currentFormMap?.component ?? EmptyContent;
const isMcp = useIsMcp(operatorName); const isMcp = useIsMcp(operatorName);
const { t } = useTranslate('flow'); const { t } = useTranslate('flow');
const { component_name: toolComponentName } = (getAgentToolById(
clickedToolId,
) ?? {}) as {
component_name: Operator;
name: string;
id: string;
};
return ( return (
<Sheet open={visible} modal={false}> <Sheet open={visible} modal={false}>
<SheetContent <SheetContent
className={cn('top-20 p-0 flex flex-col pb-20', { className={cn('top-20 p-0 flex flex-col pb-20 gap-0', {
'right-[clamp(0px,34%,620px)]': chatVisible, 'right-[clamp(0px,34%,620px)]': chatVisible,
})} })}
closeIcon={false} closeIcon={false}
> >
<SheetHeader> <SheetHeader>
<SheetTitle className="hidden"></SheetTitle> <SheetTitle className="hidden"></SheetTitle>
<section className="flex-col border-b py-2 px-5"> <section className="flex-col border-b pt-2 pb-4 px-5">
<div className="flex items-center gap-2 pb-3"> <div className="flex items-center gap-2 pb-3">
<OperatorIcon name={operatorName}></OperatorIcon> <OperatorIcon
name={toolComponentName || operatorName}
></OperatorIcon>
<TitleInput node={node}></TitleInput> <TitleInput node={node}></TitleInput>
{needsSingleStepDebugging(operatorName) && ( {needsSingleStepDebugging(operatorName) && (
<RunTooltip> <RunTooltip>
<CirclePlay <Button
className="size-3.5 cursor-pointer" variant="ghost"
size="icon"
className="size-6 !p-0 bg-transparent"
onClick={showSingleDebugDrawer} onClick={showSingleDebugDrawer}
/> >
<CirclePlay className="size-3.5 cursor-pointer" />
</Button>
</RunTooltip> </RunTooltip>
)} )}
<X onClick={hideModal} className="size-3.5 cursor-pointer" />
<Button
variant="ghost"
size="icon"
className="size-6 !p-0 bg-transparent"
onClick={hideModal}
>
<X className="size-3.5 cursor-pointer" />
</Button>
</div> </div>
{isMcp || (
<span className="text-text-secondary"> {!isMcp && (
<p className="text-text-secondary">
{t( {t(
`${lowerFirst(operatorName === Operator.Tool ? clickedToolId : operatorName)}Description`, `${lowerFirst(operatorName === Operator.Tool ? toolComponentName : operatorName)}Description`,
)} )}
</span> </p>
)} )}
</section> </section>
</SheetHeader> </SheetHeader>
<section className="pt-4 overflow-auto flex-1"> <section className="pt-4 overflow-auto flex-1">
{visible && ( {visible && (
<AgentFormContext.Provider value={node}> <AgentFormContext.Provider value={node}>

View File

@ -1,7 +1,8 @@
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input'; import { Input } from '@/components/ui/input';
import { RAGFlowNodeType } from '@/interfaces/database/agent'; import { RAGFlowNodeType } from '@/interfaces/database/agent';
import { PenLine } from 'lucide-react'; import { PenLine } from 'lucide-react';
import { useCallback, useState } from 'react'; import { useCallback, useLayoutEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BeginId, Operator } from '../constant'; import { BeginId, Operator } from '../constant';
import { useHandleNodeNameChange } from '../hooks/use-change-node-name'; import { useHandleNodeNameChange } from '../hooks/use-change-node-name';
@ -13,47 +14,75 @@ type TitleInputProps = {
export function TitleInput({ node }: TitleInputProps) { export function TitleInput({ node }: TitleInputProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const inputRef = useRef<HTMLInputElement>(null);
const { name, handleNameBlur, handleNameChange } = useHandleNodeNameChange({ const { name, handleNameBlur, handleNameChange } = useHandleNodeNameChange({
id: node?.id, id: node?.id,
data: node?.data, data: node?.data,
}); });
const operatorName: Operator = node?.data.label as Operator; const operatorName: Operator = node?.data.label as Operator;
const isMcp = useIsMcp(operatorName); const isMcp = useIsMcp(operatorName);
const [isEditingMode, setIsEditingMode] = useState(false); const [isEditingMode, setIsEditingMode] = useState(false);
const switchIsEditingMode = useCallback(() => { const switchIsEditingMode = useCallback(() => {
setIsEditingMode((prev) => !prev); setIsEditingMode((prev) => !prev);
}, []); }, []);
const handleBlur = useCallback(() => { const handleBlur = useCallback(
handleNameBlur(); (e: React.FocusEvent<HTMLInputElement>) => {
setIsEditingMode(false); if (handleNameBlur()) {
}, [handleNameBlur]); setIsEditingMode(false);
} else {
// Re-focus the input if name doesn't change successfully
e.target.focus();
e.target.select();
}
},
[handleNameBlur],
);
useLayoutEffect(() => {
if (isEditingMode && inputRef.current) {
inputRef.current.focus();
inputRef.current.select();
}
}, [isEditingMode]);
if (isMcp) { if (isMcp) {
return <div className="flex-1 text-base">MCP Config</div>; return <div className="flex-1 text-base">MCP Config</div>;
} }
return ( return (
<div className="flex items-center gap-1 flex-1"> // Give a fixed height to prevent layout shift when switching between edit and view modes
<div className="flex items-center gap-1 flex-1 h-8 mr-2">
{node?.id === BeginId ? ( {node?.id === BeginId ? (
<span>{t(BeginId)}</span> // Begin node is not editable
<span>{t(`flow.${BeginId}`)}</span>
) : isEditingMode ? ( ) : isEditingMode ? (
<Input <Input
ref={inputRef}
value={name} value={name}
onBlur={handleBlur} onBlur={handleBlur}
onKeyDown={(e) => {
// Support committing the value changes by pressing Enter
if (e.key === 'Enter') {
handleBlur(e as unknown as React.FocusEvent<HTMLInputElement>);
}
}}
onChange={handleNameChange} onChange={handleNameChange}
></Input> />
) : ( ) : (
<div className="flex items-center gap-2.5 text-base"> <div className="flex items-center gap-2.5 text-base">
{name} {name}
<PenLine
<Button
variant="transparent"
size="icon"
className="size-6 !p-0 border-0 bg-transparent"
onClick={switchIsEditingMode} onClick={switchIsEditingMode}
className="size-3.5 text-text-secondary cursor-pointer" >
/> <PenLine className="size-3.5 text-text-secondary cursor-pointer" />
</Button>
</div> </div>
)} )}
</div> </div>

View File

@ -1,4 +1,4 @@
import { BlockButton } from '@/components/ui/button'; import { BlockButton, Button } from '@/components/ui/button';
import { import {
Tooltip, Tooltip,
TooltipContent, TooltipContent,
@ -26,7 +26,7 @@ import { filterDownstreamAgentNodeIds } from '../../utils/filter-downstream-node
import { ToolPopover } from './tool-popover'; import { ToolPopover } from './tool-popover';
import { useDeleteAgentNodeMCP } from './tool-popover/use-update-mcp'; import { useDeleteAgentNodeMCP } from './tool-popover/use-update-mcp';
import { useDeleteAgentNodeTools } from './tool-popover/use-update-tools'; import { useDeleteAgentNodeTools } from './tool-popover/use-update-tools';
import { useGetAgentMCPIds, useGetAgentToolNames } from './use-get-tools'; import { useGetAgentMCPIds, useGetNodeTools } from './use-get-tools';
type ToolCardProps = React.HTMLAttributes<HTMLLIElement> & type ToolCardProps = React.HTMLAttributes<HTMLLIElement> &
PropsWithChildren & { PropsWithChildren & {
@ -79,20 +79,33 @@ function ActionButton<T>({ deleteRecord, record, edit }: ActionButtonProps<T>) {
deleteRecord(record); deleteRecord(record);
}, [deleteRecord, record]); }, [deleteRecord, record]);
// Wrapping into buttons to solve the issue that clicking icon occasionally not jumping to corresponding form
return ( return (
<div className="flex items-center gap-4 text-text-secondary"> <div className="flex items-center gap-4 text-text-secondary">
<PencilLine <Button
className="size-3.5 cursor-pointer" variant="transparent"
size="icon"
className="size-3.5 !bg-transparent !border-none"
data-tool={record} data-tool={record}
onClick={edit} onClick={edit}
/> >
<X className="size-3.5 cursor-pointer" onClick={handleDelete} /> <PencilLine className="size-full" />
</Button>
<Button
variant="transparent"
size="icon"
className="size-3.5 !bg-transparent !border-none"
onClick={handleDelete}
>
<X className="size-full" />
</Button>
</div> </div>
); );
} }
export function AgentTools() { export function AgentTools() {
const { toolNames } = useGetAgentToolNames(); const tools = useGetNodeTools();
const { deleteNodeTool } = useDeleteAgentNodeTools(); const { deleteNodeTool } = useDeleteAgentNodeTools();
const { mcpIds } = useGetAgentMCPIds(); const { mcpIds } = useGetAgentMCPIds();
const { findMcpById } = useFindMcpById(); const { findMcpById } = useFindMcpById();
@ -105,6 +118,7 @@ export function AgentTools() {
const handleEdit: MouseEventHandler<SVGSVGElement> = useCallback( const handleEdit: MouseEventHandler<SVGSVGElement> = useCallback(
(e) => { (e) => {
const toolNodeId = findAgentToolNodeById(clickedNodeId); const toolNodeId = findAgentToolNodeById(clickedNodeId);
if (toolNodeId) { if (toolNodeId) {
selectNodeIds([toolNodeId]); selectNodeIds([toolNodeId]);
showFormDrawer(e, toolNodeId); showFormDrawer(e, toolNodeId);
@ -117,19 +131,20 @@ export function AgentTools() {
<section className="space-y-2.5"> <section className="space-y-2.5">
<span className="text-text-secondary text-sm">{t('flow.tools')}</span> <span className="text-text-secondary text-sm">{t('flow.tools')}</span>
<ul className="space-y-2.5"> <ul className="space-y-2.5">
{toolNames.map((x) => ( {tools.map(({ id, component_name, name }) => (
<ToolCard key={x} isNodeTool={false}> <ToolCard key={id} isNodeTool={false}>
<div className="flex gap-2 items-center"> <div className="flex gap-2 items-center">
<OperatorIcon name={x as Operator}></OperatorIcon> <OperatorIcon name={component_name as Operator}></OperatorIcon>
{x} {component_name === Operator.Retrieval ? name : component_name}
</div> </div>
<ActionButton <ActionButton
record={x} record={id}
deleteRecord={deleteNodeTool(x)} deleteRecord={deleteNodeTool(id)}
edit={handleEdit} edit={handleEdit}
></ActionButton> />
</ToolCard> </ToolCard>
))} ))}
{mcpIds.map((id) => ( {mcpIds.map((id) => (
<ToolCard key={id} isNodeTool={false}> <ToolCard key={id} isNodeTool={false}>
{findMcpById(id)?.name} {findMcpById(id)?.name}

View File

@ -9,21 +9,19 @@ import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context';
import useGraphStore from '@/pages/agent/store'; import useGraphStore from '@/pages/agent/store';
import { Position } from '@xyflow/react'; import { Position } from '@xyflow/react';
import { t } from 'i18next'; import { t } from 'i18next';
import { PropsWithChildren, useCallback, useContext, useEffect } from 'react'; import { useContext, useEffect } from 'react';
import { useGetAgentMCPIds, useGetAgentToolNames } from '../use-get-tools'; import { useGetAgentMCPIds, useGetAgentToolNames } from '../use-get-tools';
import { MCPCommand, ToolCommand } from './tool-command'; import { MCPCommand, ToolCommand } from './tool-command';
import { useUpdateAgentNodeMCP } from './use-update-mcp'; import { useUpdateAgentNodeMCP } from './use-update-mcp';
import { useUpdateAgentNodeTools } from './use-update-tools';
enum ToolType { enum ToolType {
Common = 'common', Common = 'common',
MCP = 'mcp', MCP = 'mcp',
} }
export function ToolPopover({ children }: PropsWithChildren) { export function ToolPopover({ children }: React.PropsWithChildren) {
const { addCanvasNode } = useContext(AgentInstanceContext); const { addCanvasNode } = useContext(AgentInstanceContext);
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
const { updateNodeTools } = useUpdateAgentNodeTools();
const { toolNames } = useGetAgentToolNames(); const { toolNames } = useGetAgentToolNames();
const deleteAgentToolNodeById = useGraphStore( const deleteAgentToolNodeById = useGraphStore(
(state) => state.deleteAgentToolNodeById, (state) => state.deleteAgentToolNodeById,
@ -31,15 +29,6 @@ export function ToolPopover({ children }: PropsWithChildren) {
const { mcpIds } = useGetAgentMCPIds(); const { mcpIds } = useGetAgentMCPIds();
const { updateNodeMCP } = useUpdateAgentNodeMCP(); const { updateNodeMCP } = useUpdateAgentNodeMCP();
const handleChange = useCallback(
(value: string[]) => {
if (Array.isArray(value) && node?.id) {
updateNodeTools(value);
}
},
[node?.id, updateNodeTools],
);
useEffect(() => { useEffect(() => {
const total = toolNames.length + mcpIds.length; const total = toolNames.length + mcpIds.length;
if (node?.id) { if (node?.id) {
@ -72,10 +61,7 @@ export function ToolPopover({ children }: PropsWithChildren) {
<TabsTrigger value={ToolType.MCP}>MCP</TabsTrigger> <TabsTrigger value={ToolType.MCP}>MCP</TabsTrigger>
</TabsList> </TabsList>
<TabsContent value={ToolType.Common}> <TabsContent value={ToolType.Common}>
<ToolCommand <ToolCommand />
onChange={handleChange}
value={toolNames}
></ToolCommand>
</TabsContent> </TabsContent>
<TabsContent value={ToolType.MCP}> <TabsContent value={ToolType.MCP}>
<MCPCommand value={mcpIds} onChange={updateNodeMCP}></MCPCommand> <MCPCommand value={mcpIds} onChange={updateNodeMCP}></MCPCommand>

View File

@ -12,8 +12,10 @@ import { Operator } from '@/pages/agent/constant';
import OperatorIcon from '@/pages/agent/operator-icon'; import OperatorIcon from '@/pages/agent/operator-icon';
import { t } from 'i18next'; import { t } from 'i18next';
import { lowerFirst } from 'lodash'; import { lowerFirst } from 'lodash';
import { LucidePlus } from 'lucide-react';
import { PropsWithChildren, useCallback, useEffect, useState } from 'react'; import { PropsWithChildren, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetNodeTools, useUpdateAgentNodeTools } from './use-update-tools';
const Menus = [ const Menus = [
{ {
@ -66,7 +68,13 @@ function ToolCommandItem({
}: ToolCommandItemProps & PropsWithChildren) { }: ToolCommandItemProps & PropsWithChildren) {
return ( return (
<CommandItem className="cursor-pointer" onSelect={() => toggleOption(id)}> <CommandItem className="cursor-pointer" onSelect={() => toggleOption(id)}>
<Checkbox checked={isSelected} /> {id === Operator.Retrieval ? (
<span>
<LucidePlus className="size-4" />
</span>
) : (
<Checkbox checked={isSelected} />
)}
{children} {children}
</CommandItem> </CommandItem>
); );
@ -98,12 +106,12 @@ function useHandleSelectChange({ onChange, value }: ToolCommandProps) {
}; };
} }
// eslint-disable-next-line
export function ToolCommand({ value, onChange }: ToolCommandProps) { export function ToolCommand({ value, onChange }: ToolCommandProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const { toggleOption, currentValue } = useHandleSelectChange({
onChange, const currentValue = useGetNodeTools();
value, const { updateNodeTools } = useUpdateAgentNodeTools();
});
return ( return (
<Command> <Command>
@ -112,22 +120,17 @@ export function ToolCommand({ value, onChange }: ToolCommandProps) {
<CommandEmpty>No results found.</CommandEmpty> <CommandEmpty>No results found.</CommandEmpty>
{Menus.map((x) => ( {Menus.map((x) => (
<CommandGroup heading={x.label} key={x.label}> <CommandGroup heading={x.label} key={x.label}>
{x.list.map((y) => { {x.list.map((y) => (
const isSelected = currentValue.includes(y); <ToolCommandItem
return ( key={y}
<ToolCommandItem id={y}
key={y} toggleOption={updateNodeTools}
id={y} isSelected={currentValue.some((x) => x.component_name === y)}
toggleOption={toggleOption} >
isSelected={isSelected} <OperatorIcon name={y as Operator}></OperatorIcon>
> <span>{t(`flow.${lowerFirst(y)}`)}</span>
<> </ToolCommandItem>
<OperatorIcon name={y as Operator}></OperatorIcon> ))}
<span>{t(`flow.${lowerFirst(y)}`)}</span>
</>
</ToolCommandItem>
);
})}
</CommandGroup> </CommandGroup>
))} ))}
</CommandList> </CommandList>

View File

@ -16,32 +16,59 @@ export function useGetNodeTools() {
} }
export function useUpdateAgentNodeTools() { export function useUpdateAgentNodeTools() {
const { updateNodeForm } = useGraphStore((state) => state); const { generateAgentToolName, generateAgentToolId, updateNodeForm } =
const node = useContext(AgentFormContext); useGraphStore((state) => state);
const node = useContext(AgentFormContext)!;
const tools = useGetNodeTools(); const tools = useGetNodeTools();
const { initializeAgentToolValues } = useAgentToolInitialValues(); const { initializeAgentToolValues } = useAgentToolInitialValues();
const updateNodeTools = useCallback( const updateNodeTools = useCallback(
(value: string[]) => { (value: string) => {
if (node?.id) { if (!node?.id) return;
const nextValue = value.reduce<IAgentForm['tools']>((pre, cur) => {
const tool = tools.find((x) => x.component_name === cur);
pre.push(
tool
? tool
: {
component_name: cur,
name: cur,
params: initializeAgentToolValues(cur as Operator),
},
);
return pre;
}, []);
updateNodeForm(node?.id, nextValue, ['tools']); // Append
if (value === Operator.Retrieval) {
updateNodeForm(
node.id,
[
...tools,
{
component_name: value,
name: generateAgentToolName(node.id, value),
params: initializeAgentToolValues(value as Operator),
id: generateAgentToolId(value),
},
],
['tools'],
);
}
// Toggle
else {
updateNodeForm(
node.id,
tools.some((x) => x.component_name === value)
? tools.filter((x) => x.component_name !== value)
: [
...tools,
{
component_name: value,
name: value,
params: initializeAgentToolValues(value as Operator),
id: generateAgentToolId(value),
},
],
['tools'],
);
} }
}, },
[initializeAgentToolValues, node?.id, tools, updateNodeForm], [
generateAgentToolName,
generateAgentToolId,
initializeAgentToolValues,
node?.id,
tools,
updateNodeForm,
],
); );
return { updateNodeTools }; return { updateNodeTools };
@ -53,8 +80,9 @@ export function useDeleteAgentNodeTools() {
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
const deleteNodeTool = useCallback( const deleteNodeTool = useCallback(
(value: string) => () => { (toolId: string) => () => {
const nextTools = tools.filter((x) => x.component_name !== value); const nextTools = tools.filter((x) => x.id !== toolId);
if (node?.id) { if (node?.id) {
updateNodeForm(node?.id, nextTools, ['tools']); updateNodeForm(node?.id, nextTools, ['tools']);
} }

View File

@ -3,6 +3,11 @@ import { get } from 'lodash';
import { useContext, useMemo } from 'react'; import { useContext, useMemo } from 'react';
import { AgentFormContext } from '../../context'; import { AgentFormContext } from '../../context';
export function useGetNodeTools() {
const node = useContext(AgentFormContext);
return get(node, 'data.form.tools', []) as IAgentForm['tools'];
}
export function useGetAgentToolNames() { export function useGetAgentToolNames() {
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);

View File

@ -7,9 +7,11 @@ const EmptyContent = () => <div></div>;
function ToolForm() { function ToolForm() {
const clickedToolId = useGraphStore((state) => state.clickedToolId); const clickedToolId = useGraphStore((state) => state.clickedToolId);
const { getAgentToolById } = useGraphStore();
const tool = getAgentToolById(clickedToolId);
const ToolForm = const ToolForm =
ToolFormConfigMap[clickedToolId as keyof typeof ToolFormConfigMap] ?? ToolFormConfigMap[tool?.component_name as keyof typeof ToolFormConfigMap] ??
MCPForm ?? MCPForm ??
EmptyContent; EmptyContent;

View File

@ -3,7 +3,6 @@ import { useMemo } from 'react';
import { Operator } from '../../constant'; import { Operator } from '../../constant';
import { useAgentToolInitialValues } from '../../hooks/use-agent-tool-initial-values'; import { useAgentToolInitialValues } from '../../hooks/use-agent-tool-initial-values';
import useGraphStore from '../../store'; import useGraphStore from '../../store';
import { getAgentNodeTools } from '../../utils';
export enum SearchDepth { export enum SearchDepth {
Basic = 'basic', Basic = 'basic',
@ -16,22 +15,23 @@ export enum Topic {
} }
export function useValues() { export function useValues() {
const { clickedToolId, clickedNodeId, findUpstreamNodeById } = useGraphStore( const {
(state) => state, clickedToolId,
); clickedNodeId,
findUpstreamNodeById,
getAgentToolById,
} = useGraphStore();
const { initializeAgentToolValues } = useAgentToolInitialValues(); const { initializeAgentToolValues } = useAgentToolInitialValues();
const values = useMemo(() => { const values = useMemo(() => {
const agentNode = findUpstreamNodeById(clickedNodeId); const agentNode = findUpstreamNodeById(clickedNodeId);
const tools = getAgentNodeTools(agentNode); const tool = getAgentToolById(clickedToolId, agentNode!);
const formData = tool?.params;
const formData = tools.find(
(x) => x.component_name === clickedToolId,
)?.params;
if (isEmpty(formData)) { if (isEmpty(formData)) {
const defaultValues = initializeAgentToolValues( const defaultValues = initializeAgentToolValues(
clickedNodeId as Operator, (tool?.component_name || clickedNodeId) as Operator,
); );
return defaultValues; return defaultValues;
@ -44,6 +44,7 @@ export function useValues() {
clickedNodeId, clickedNodeId,
clickedToolId, clickedToolId,
findUpstreamNodeById, findUpstreamNodeById,
getAgentToolById,
initializeAgentToolValues, initializeAgentToolValues,
]); ]);

View File

@ -1,39 +1,38 @@
import { useEffect } from 'react'; import { useEffect } from 'react';
import { UseFormReturn, useWatch } from 'react-hook-form'; import { UseFormReturn, useWatch } from 'react-hook-form';
import useGraphStore from '../../store'; import useGraphStore from '../../store';
import { getAgentNodeTools } from '../../utils';
export function useWatchFormChange(form?: UseFormReturn<any>) { export function useWatchFormChange(form?: UseFormReturn<any>) {
let values = useWatch({ control: form?.control }); let values = useWatch({ control: form?.control });
const { clickedToolId, clickedNodeId, findUpstreamNodeById, updateNodeForm } =
useGraphStore((state) => state); const {
clickedToolId,
clickedNodeId,
findUpstreamNodeById,
getAgentToolById,
updateAgentToolById,
updateNodeForm,
} = useGraphStore();
useEffect(() => { useEffect(() => {
const agentNode = findUpstreamNodeById(clickedNodeId); const agentNode = findUpstreamNodeById(clickedNodeId);
// Manually triggered form updates are synchronized to the canvas // Manually triggered form updates are synchronized to the canvas
if (agentNode && form?.formState.isDirty) { if (agentNode && form?.formState.isDirty) {
const agentNodeId = agentNode?.id; updateAgentToolById(agentNode, clickedToolId, {
const tools = getAgentNodeTools(agentNode); params: {
...(values ?? {}),
values = form?.getValues(); },
const nextTools = tools.map((x) => {
if (x.component_name === clickedToolId) {
return {
...x,
params: {
...values,
},
};
}
return x;
}); });
const nextValues = {
...(agentNode?.data?.form ?? {}),
tools: nextTools,
};
updateNodeForm(agentNodeId, nextValues);
} }
}, [form?.formState.isDirty, updateNodeForm, values]); }, [
clickedNodeId,
clickedToolId,
findUpstreamNodeById,
form,
form?.formState.isDirty,
getAgentToolById,
updateAgentToolById,
updateNodeForm,
values,
]);
} }

View File

@ -6,14 +6,13 @@ import {
SetStateAction, SetStateAction,
useCallback, useCallback,
useEffect, useEffect,
useMemo,
useState, useState,
} from 'react'; } from 'react';
import { Operator } from '../constant'; import { Operator } from '../constant';
import useGraphStore from '../store'; import useGraphStore from '../store';
import { getAgentNodeTools } from '../utils'; import { getAgentNodeTools } from '../utils';
export function useHandleTooNodeNameChange({ export function useHandleToolNodeNameChange({
id, id,
name, name,
setName, setName,
@ -22,48 +21,44 @@ export function useHandleTooNodeNameChange({
name?: string; name?: string;
setName: Dispatch<SetStateAction<string>>; setName: Dispatch<SetStateAction<string>>;
}) { }) {
const { clickedToolId, findUpstreamNodeById, updateNodeForm } = useGraphStore( const {
(state) => state, clickedToolId,
); findUpstreamNodeById,
const agentNode = findUpstreamNodeById(id); getAgentToolById,
updateAgentToolById,
} = useGraphStore((state) => state);
const agentNode = findUpstreamNodeById(id)!;
const tools = getAgentNodeTools(agentNode); const tools = getAgentNodeTools(agentNode);
const previousName = getAgentToolById(clickedToolId, agentNode)?.name;
const previousName = useMemo(() => {
const tool = tools.find((x) => x.component_name === clickedToolId);
return tool?.name || tool?.component_name;
}, [clickedToolId, tools]);
const handleToolNameBlur = useCallback(() => { const handleToolNameBlur = useCallback(() => {
const trimmedName = trim(name); const trimmedName = trim(name);
const existsSameName = tools.some((x) => x.name === trimmedName); const existsSameName = tools.some((x) => x.name === trimmedName);
if (trimmedName === '' || existsSameName) {
if (existsSameName && previousName !== name) { // Not changed
message.error('The name cannot be repeated'); if (trimmedName === '') {
}
setName(previousName || ''); setName(previousName || '');
return; return true;
}
if (existsSameName && previousName !== name) {
message.error('The name cannot be repeated');
return false;
} }
if (agentNode?.id) { if (agentNode?.id) {
const nextTools = tools.map((x) => { updateAgentToolById(agentNode, clickedToolId, { name });
if (x.component_name === clickedToolId) {
return {
...x,
name,
};
}
return x;
});
updateNodeForm(agentNode?.id, nextTools, ['tools']);
} }
return true;
}, [ }, [
agentNode?.id, agentNode,
clickedToolId, clickedToolId,
name, name,
previousName, previousName,
setName, setName,
tools, tools,
updateNodeForm, updateAgentToolById,
]); ]);
return { handleToolNameBlur, previousToolName: previousName }; return { handleToolNameBlur, previousToolName: previousName };
@ -83,28 +78,35 @@ export const useHandleNodeNameChange = ({
const previousName = data?.name; const previousName = data?.name;
const isToolNode = getOperatorTypeFromId(id) === Operator.Tool; const isToolNode = getOperatorTypeFromId(id) === Operator.Tool;
const { handleToolNameBlur, previousToolName } = useHandleTooNodeNameChange({ const { handleToolNameBlur, previousToolName } = useHandleToolNodeNameChange({
id, id,
name, name,
setName, setName,
}); });
const handleNameBlur = useCallback(() => { const handleNameBlur = useCallback(() => {
const trimmedName = trim(name);
const existsSameName = nodes.some((x) => x.data.name === name); const existsSameName = nodes.some((x) => x.data.name === name);
if (trim(name) === '' || existsSameName) {
if (existsSameName && previousName !== name) { // Not changed
message.error('The name cannot be repeated'); if (!trimmedName) {
} setName(previousName || '');
setName(previousName); return true;
return; }
if (existsSameName && previousName !== name) {
message.error('The name cannot be repeated');
return false;
} }
if (id) { if (id) {
updateNodeName(id, name); updateNodeName(id, name);
} }
return true;
}, [name, id, updateNodeName, previousName, nodes]); }, [name, id, updateNodeName, previousName, nodes]);
const handleNameChange = useCallback((e: ChangeEvent<any>) => { const handleNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setName(e.target.value); setName(e.target.value);
}, []); }, []);

View File

@ -2,10 +2,12 @@ import { Operator } from '../constant';
import useGraphStore from '../store'; import useGraphStore from '../store';
export function useIsMcp(operatorName: Operator) { export function useIsMcp(operatorName: Operator) {
const clickedToolId = useGraphStore((state) => state.clickedToolId); const { clickedToolId, getAgentToolById } = useGraphStore();
const { component_name: toolName } = getAgentToolById(clickedToolId) ?? {};
return ( return (
operatorName === Operator.Tool && operatorName === Operator.Tool &&
Object.values(Operator).every((x) => x !== clickedToolId) Object.values(Operator).every((x) => x !== toolName)
); );
} }

View File

@ -24,7 +24,9 @@ export const useShowFormDrawer = () => {
const handleShow = useCallback( const handleShow = useCallback(
(e: React.MouseEvent<Element>, nodeId: string) => { (e: React.MouseEvent<Element>, nodeId: string) => {
const tool = get(e.target, 'dataset.tool'); const toolId = (e.target as HTMLElement).dataset.toolId;
const tool = (e.target as HTMLElement).dataset.tool;
// TODO: Operator type judgment should be used // TODO: Operator type judgment should be used
const operatorType = getOperatorTypeFromId(nodeId); const operatorType = getOperatorTypeFromId(nodeId);
if ( if (
@ -36,7 +38,8 @@ export const useShowFormDrawer = () => {
return; return;
} }
setClickedNodeId(nodeId); setClickedNodeId(nodeId);
setClickedToolId(tool); // Guess this could gracefully handle the case where the tool id is not provided?
setClickedToolId(toolId || tool);
showFormDrawer(); showFormDrawer();
}, },
[getOperatorTypeFromId, setClickedNodeId, setClickedToolId, showFormDrawer], [getOperatorTypeFromId, setClickedNodeId, setClickedToolId, showFormDrawer],

View File

@ -1,4 +1,5 @@
import { RAGFlowNodeType } from '@/interfaces/database/flow'; import type { IAgentForm } from '@/interfaces/database/agent';
import { IAgentNode, RAGFlowNodeType } from '@/interfaces/database/flow';
import type {} from '@redux-devtools/extension'; import type {} from '@redux-devtools/extension';
import { import {
Connection, Connection,
@ -14,10 +15,15 @@ import {
applyEdgeChanges, applyEdgeChanges,
applyNodeChanges, applyNodeChanges,
} from '@xyflow/react'; } from '@xyflow/react';
import { cloneDeep, omit } from 'lodash'; import humanId from 'human-id';
import differenceWith from 'lodash/differenceWith'; import {
import intersectionWith from 'lodash/intersectionWith'; cloneDeep,
import lodashSet from 'lodash/set'; differenceWith,
intersectionWith,
get as lodashGet,
set as lodashSet,
omit,
} from 'lodash';
import { create } from 'zustand'; import { create } from 'zustand';
import { devtools } from 'zustand/middleware'; import { devtools } from 'zustand/middleware';
import { immer } from 'zustand/middleware/immer'; import { immer } from 'zustand/middleware/immer';
@ -26,12 +32,26 @@ import {
duplicateNodeForm, duplicateNodeForm,
generateDuplicateNode, generateDuplicateNode,
generateNodeNamesWithIncreasingIndex, generateNodeNamesWithIncreasingIndex,
getAgentNodeTools,
getOperatorIndex, getOperatorIndex,
isEdgeEqual, isEdgeEqual,
mapEdgeMouseEvent, mapEdgeMouseEvent,
} from './utils'; } from './utils';
import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node'; import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node';
type IAgentTool = IAgentForm['tools'][number];
interface GetAgentToolByIdFunc {
(id: string): IAgentTool | undefined;
(id: string, agentNode: RAGFlowNodeType): IAgentTool | undefined;
(id: string, agentNodeId: string): IAgentTool | undefined;
}
interface UpdateAgentToolByIdFunc {
(agentNode: RAGFlowNodeType, id: string, value?: Partial<IAgentTool>): void;
(agentNodeId: string, id: string, value?: Partial<IAgentTool>): void;
}
export type RFState = { export type RFState = {
nodes: RAGFlowNodeType[]; nodes: RAGFlowNodeType[];
edges: Edge[]; edges: Edge[];
@ -81,6 +101,11 @@ export type RFState = {
getParentIdById: (id?: string | null) => string | undefined; getParentIdById: (id?: string | null) => string | undefined;
updateNodeName: (id: string, name: string) => void; updateNodeName: (id: string, name: string) => void;
generateNodeName: (name: string) => string; generateNodeName: (name: string) => string;
generateAgentToolName: (id: string, name: string) => string;
generateAgentToolId: (prefix: string) => string;
getAllAgentTools: () => IAgentTool[];
getAgentToolById: GetAgentToolByIdFunc;
updateAgentToolById: UpdateAgentToolByIdFunc;
setClickedNodeId: (id?: string) => void; setClickedNodeId: (id?: string) => void;
setClickedToolId: (id?: string) => void; setClickedToolId: (id?: string) => void;
findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined; findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined;
@ -501,6 +526,95 @@ const useGraphStore = create<RFState>()(
return generateNodeNamesWithIncreasingIndex(name, nodes); return generateNodeNamesWithIncreasingIndex(name, nodes);
}, },
generateAgentToolName: (id: string, name: string) => {
const node = get().nodes.find(
(x) => x.id === id,
) as IAgentNode<IAgentForm>;
if (!node) {
return '';
}
const tools = (node.data.form!.tools as any[]).filter(
(x) => x.component_name === name,
);
const lastIndex = tools.length
? tools
.map((x) => {
const idx = x.name.match(/(\d+)$/)?.[1];
return idx && isNaN(idx) ? -1 : Number(idx);
})
.sort((a, b) => a - b)
.at(-1) ?? -1
: -1;
return `${name}_${lastIndex + 1}`;
},
generateAgentToolId: (prefix: string) => {
const allAgentToolIds = get()
.getAllAgentTools()
.map((t) => t.id || t.component_name);
let id: string;
// Loop for avoiding id collisions
do {
id = `${prefix}:${humanId()}`;
} while (allAgentToolIds.includes(id));
return id;
},
getAllAgentTools: () => {
return get()
.nodes.filter((n) => n?.data?.label === Operator.Agent)
.flatMap((n) => n?.data?.form?.tools);
},
getAgentToolById: (
id: string,
nodeOrNodeId?: RAGFlowNodeType | string,
) => {
// eslint-disable-next-line eqeqeq
const tools =
nodeOrNodeId != null
? getAgentNodeTools(
typeof nodeOrNodeId === 'string'
? get().getNode(nodeOrNodeId)
: nodeOrNodeId,
)
: get().getAllAgentTools();
// For backward compatibility
return tools.find((t) => (t.id || t.component_name) === id);
},
updateAgentToolById: (
nodeOrNodeId: RAGFlowNodeType | string,
id: string,
value?: Partial<IAgentTool>,
) => {
const { getNode, updateNodeForm } = get();
const agentNode =
typeof nodeOrNodeId === 'string'
? getNode(nodeOrNodeId)
: nodeOrNodeId;
if (!agentNode) {
return;
}
const toolIndex = getAgentNodeTools(agentNode).findIndex(
(t) => (t.id || t.component_name) === id,
);
updateNodeForm(
agentNode.id,
{
...lodashGet(agentNode.data.form, ['tools', toolIndex], {}),
...(value ?? {}),
},
['tools', toolIndex],
);
},
setClickedToolId: (id?: string) => { setClickedToolId: (id?: string) => {
set({ clickedToolId: id }); set({ clickedToolId: id });
}, },

View File

@ -120,13 +120,17 @@ function buildAgentTools(edges: Edge[], nodes: Node[], nodeId: string) {
return { return {
component_name: Operator.Agent, component_name: Operator.Agent,
id, id,
name: name as string, // Cast name to string and provide fallback name,
params: { ...formData }, params: { ...formData },
}; };
}), }),
); );
} }
return { params, name: node?.data.name, id: node?.id }; return { params, name: node?.data.name, id: node?.id } as {
params: IAgentForm;
name: string;
id: string;
};
} }
function filterTargetsBySourceHandleId(edges: Edge[], handleId: string) { function filterTargetsBySourceHandleId(edges: Edge[], handleId: string) {