From 8fb6b5d9451a7a92aae660bb0205ca66a0b32df8 Mon Sep 17 00:00:00 2001 From: balibabu Date: Mon, 9 Jun 2025 19:19:48 +0800 Subject: [PATCH] Feat: Add agent operator node from agent form #3221 (#8144) ### What problem does this PR solve? Feat: Add agent operator node from agent form #3221 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/pages/agent/canvas/index.tsx | 25 +- .../pages/agent/canvas/node/agent-node.tsx | 13 + web/src/pages/agent/context.ts | 10 + web/src/pages/agent/form/agent-form/index.tsx | 11 +- web/src/pages/agent/hooks.tsx | 2 +- web/src/pages/agent/hooks/use-add-node.ts | 224 ++++++++++++++++++ 6 files changed, 272 insertions(+), 13 deletions(-) create mode 100644 web/src/pages/agent/hooks/use-add-node.ts diff --git a/web/src/pages/agent/canvas/index.tsx b/web/src/pages/agent/canvas/index.tsx index fbc27b097..9fbd7177b 100644 --- a/web/src/pages/agent/canvas/index.tsx +++ b/web/src/pages/agent/canvas/index.tsx @@ -6,6 +6,7 @@ import { } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; import { ChatSheet } from '../chat/chat-sheet'; +import { AgentInstanceContext } from '../context'; import FormSheet from '../form-sheet/next'; import { useHandleDrop, @@ -13,6 +14,7 @@ import { useValidateConnection, useWatchNodeFormDataChange, } from '../hooks'; +import { useAddNode } from '../hooks/use-add-node'; import { useBeforeDelete } from '../hooks/use-before-delete'; import { useShowDrawer } from '../hooks/use-show-drawer'; import RunSheet from '../run-sheet'; @@ -77,7 +79,8 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) { } = useSelectCanvasData(); const isValidConnection = useValidateConnection(); - const { onDrop, onDragOver, setReactFlowInstance } = useHandleDrop(); + const { onDrop, onDragOver, setReactFlowInstance, reactFlowInstance } = + useHandleDrop(); const { onNodeClick, @@ -101,6 +104,8 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) { useWatchNodeFormDataChange(); + const { addCanvasNode } = useAddNode(reactFlowInstance); + return (
{formDrawerVisible && ( - + + + )} {chatVisible && ( + + ); diff --git a/web/src/pages/agent/context.ts b/web/src/pages/agent/context.ts index 6dc1cd476..ddb2d6ddb 100644 --- a/web/src/pages/agent/context.ts +++ b/web/src/pages/agent/context.ts @@ -1,6 +1,16 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { createContext } from 'react'; +import { useAddNode } from './hooks/use-add-node'; export const AgentFormContext = createContext( undefined, ); + +type AgentInstanceContextType = Pick< + ReturnType, + 'addCanvasNode' +>; + +export const AgentInstanceContext = createContext( + {} as AgentInstanceContextType, +); diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx index edffbd2a7..82eb2cd31 100644 --- a/web/src/pages/agent/form/agent-form/index.tsx +++ b/web/src/pages/agent/form/agent-form/index.tsx @@ -11,11 +11,12 @@ import { FormLabel, } from '@/components/ui/form'; import { zodResolver } from '@hookform/resolvers/zod'; -import { useMemo } from 'react'; +import { useContext, useMemo } from 'react'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { initialAgentValues } from '../../constant'; +import { Operator, initialAgentValues } from '../../constant'; +import { AgentInstanceContext } from '../../context'; import { INextOperatorForm } from '../../interface'; import { Output } from '../components/output'; import { PromptEditor } from '../components/prompt-editor'; @@ -62,6 +63,8 @@ const AgentForm = ({ node }: INextOperatorForm) => { useWatchFormChange(node?.id, form); + const { addCanvasNode } = useContext(AgentInstanceContext); + return (
{ )} /> - Add Agent + + Add Agent +
diff --git a/web/src/pages/agent/hooks.tsx b/web/src/pages/agent/hooks.tsx index 0257e19dd..a868cfa14 100644 --- a/web/src/pages/agent/hooks.tsx +++ b/web/src/pages/agent/hooks.tsx @@ -263,7 +263,7 @@ export const useHandleDrop = () => { [reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode], ); - return { onDrop, onDragOver, setReactFlowInstance }; + return { onDrop, onDragOver, setReactFlowInstance, reactFlowInstance }; }; export const useHandleFormValuesChange = ( diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts new file mode 100644 index 000000000..b8166c92b --- /dev/null +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -0,0 +1,224 @@ +import { useFetchModelId } from '@/hooks/logic-hooks'; +import { Node, Position, ReactFlowInstance } from '@xyflow/react'; +import humanId from 'human-id'; +import { lowerFirst } from 'lodash'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + NodeMap, + Operator, + initialAgentValues, + initialAkShareValues, + initialArXivValues, + initialBaiduFanyiValues, + initialBaiduValues, + initialBeginValues, + initialBingValues, + initialCategorizeValues, + initialCodeValues, + initialConcentratorValues, + initialCrawlerValues, + initialDeepLValues, + initialDuckValues, + initialEmailValues, + initialExeSqlValues, + initialGenerateValues, + initialGithubValues, + initialGoogleScholarValues, + initialGoogleValues, + initialInvokeValues, + initialIterationValues, + initialJin10Values, + initialKeywordExtractValues, + initialMessageValues, + initialNoteValues, + initialPubMedValues, + initialQWeatherValues, + initialRelevantValues, + initialRetrievalValues, + initialRewriteQuestionValues, + initialSwitchValues, + initialTemplateValues, + initialTuShareValues, + initialWaitingDialogueValues, + initialWenCaiValues, + initialWikipediaValues, + initialYahooFinanceValues, +} from '../constant'; +import useGraphStore from '../store'; +import { + generateNodeNamesWithIncreasingIndex, + getNodeDragHandle, + getRelativePositionToIterationNode, +} from '../utils'; + +export const useInitializeOperatorParams = () => { + const llmId = useFetchModelId(); + + const initialFormValuesMap = useMemo(() => { + return { + [Operator.Begin]: initialBeginValues, + [Operator.Retrieval]: initialRetrievalValues, + [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId }, + [Operator.Answer]: {}, + [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId }, + [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId }, + [Operator.RewriteQuestion]: { + ...initialRewriteQuestionValues, + llm_id: llmId, + }, + [Operator.Message]: initialMessageValues, + [Operator.KeywordExtract]: { + ...initialKeywordExtractValues, + llm_id: llmId, + }, + [Operator.DuckDuckGo]: initialDuckValues, + [Operator.Baidu]: initialBaiduValues, + [Operator.Wikipedia]: initialWikipediaValues, + [Operator.PubMed]: initialPubMedValues, + [Operator.ArXiv]: initialArXivValues, + [Operator.Google]: initialGoogleValues, + [Operator.Bing]: initialBingValues, + [Operator.GoogleScholar]: initialGoogleScholarValues, + [Operator.DeepL]: initialDeepLValues, + [Operator.GitHub]: initialGithubValues, + [Operator.BaiduFanyi]: initialBaiduFanyiValues, + [Operator.QWeather]: initialQWeatherValues, + [Operator.ExeSQL]: { ...initialExeSqlValues, llm_id: llmId }, + [Operator.Switch]: initialSwitchValues, + [Operator.WenCai]: initialWenCaiValues, + [Operator.AkShare]: initialAkShareValues, + [Operator.YahooFinance]: initialYahooFinanceValues, + [Operator.Jin10]: initialJin10Values, + [Operator.Concentrator]: initialConcentratorValues, + [Operator.TuShare]: initialTuShareValues, + [Operator.Note]: initialNoteValues, + [Operator.Crawler]: initialCrawlerValues, + [Operator.Invoke]: initialInvokeValues, + [Operator.Template]: initialTemplateValues, + [Operator.Email]: initialEmailValues, + [Operator.Iteration]: initialIterationValues, + [Operator.IterationStart]: initialIterationValues, + [Operator.Code]: initialCodeValues, + [Operator.WaitingDialogue]: initialWaitingDialogueValues, + [Operator.Agent]: { ...initialAgentValues, llm_id: llmId }, + }; + }, [llmId]); + + const initializeOperatorParams = useCallback( + (operatorName: Operator) => { + return initialFormValuesMap[operatorName]; + }, + [initialFormValuesMap], + ); + + return initializeOperatorParams; +}; + +export const useGetNodeName = () => { + const { t } = useTranslation(); + + return (type: string) => { + const name = t(`flow.${lowerFirst(type)}`); + return name; + }; +}; + +export function useAddNode(reactFlowInstance?: ReactFlowInstance) { + const addNode = useGraphStore((state) => state.addNode); + const getNode = useGraphStore((state) => state.getNode); + const addEdge = useGraphStore((state) => state.addEdge); + const getNodeName = useGetNodeName(); + const initializeOperatorParams = useInitializeOperatorParams(); + const nodes = useGraphStore((state) => state.nodes); + // const [reactFlowInstance, setReactFlowInstance] = + // useState>(); + + const addCanvasNode = useCallback( + (type: string, id?: string) => (event: React.MouseEvent) => { + // reactFlowInstance.project was renamed to reactFlowInstance.screenToFlowPosition + // and you don't need to subtract the reactFlowBounds.left/top anymore + // details: https://@xyflow/react.dev/whats-new/2023-11-10 + const position = reactFlowInstance?.screenToFlowPosition({ + x: event.clientX, + y: event.clientY, + }); + + const newNode: Node = { + id: `${type}:${humanId()}`, + type: NodeMap[type as Operator] || 'ragNode', + position: position || { + x: 0, + y: 0, + }, + data: { + label: `${type}`, + name: generateNodeNamesWithIncreasingIndex(getNodeName(type), nodes), + form: initializeOperatorParams(type as Operator), + }, + sourcePosition: Position.Right, + targetPosition: Position.Left, + dragHandle: getNodeDragHandle(type), + }; + + if (type === Operator.Iteration) { + newNode.width = 500; + newNode.height = 250; + const iterationStartNode: Node = { + id: `${Operator.IterationStart}:${humanId()}`, + type: 'iterationStartNode', + position: { x: 50, y: 100 }, + // draggable: false, + data: { + label: Operator.IterationStart, + name: Operator.IterationStart, + form: {}, + }, + parentId: newNode.id, + extent: 'parent', + }; + addNode(newNode); + addNode(iterationStartNode); + } else if (type === Operator.Agent) { + const agentNode = getNode(id); + if (agentNode) { + newNode.position = { + x: agentNode.position.x + 82, + y: agentNode.position.y + 140, + }; + } + addNode(newNode); + if (id) { + addEdge({ + source: id, + target: newNode.id, + sourceHandle: 'e', + targetHandle: 'f', + }); + } + } else { + const subNodeOfIteration = getRelativePositionToIterationNode( + nodes, + position, + ); + if (subNodeOfIteration) { + newNode.parentId = subNodeOfIteration.parentId; + newNode.position = subNodeOfIteration.position; + newNode.extent = 'parent'; + } + addNode(newNode); + } + }, + [ + addEdge, + addNode, + getNode, + getNodeName, + initializeOperatorParams, + nodes, + reactFlowInstance, + ], + ); + + return { addCanvasNode }; +}