diff --git a/web/src/pages/agent/canvas/index.tsx b/web/src/pages/agent/canvas/index.tsx
index fbc27b097..9fbd7177b 100644
--- a/web/src/pages/agent/canvas/index.tsx
+++ b/web/src/pages/agent/canvas/index.tsx
@@ -6,6 +6,7 @@ import {
} from '@xyflow/react';
import '@xyflow/react/dist/style.css';
import { ChatSheet } from '../chat/chat-sheet';
+import { AgentInstanceContext } from '../context';
import FormSheet from '../form-sheet/next';
import {
useHandleDrop,
@@ -13,6 +14,7 @@ import {
useValidateConnection,
useWatchNodeFormDataChange,
} from '../hooks';
+import { useAddNode } from '../hooks/use-add-node';
import { useBeforeDelete } from '../hooks/use-before-delete';
import { useShowDrawer } from '../hooks/use-show-drawer';
import RunSheet from '../run-sheet';
@@ -77,7 +79,8 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
} = useSelectCanvasData();
const isValidConnection = useValidateConnection();
- const { onDrop, onDragOver, setReactFlowInstance } = useHandleDrop();
+ const { onDrop, onDragOver, setReactFlowInstance, reactFlowInstance } =
+ useHandleDrop();
const {
onNodeClick,
@@ -101,6 +104,8 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
useWatchNodeFormDataChange();
+ const { addCanvasNode } = useAddNode(reactFlowInstance);
+
return (
{formDrawerVisible && (
-
+
+
+
)}
{chatVisible && (
+
+
);
diff --git a/web/src/pages/agent/context.ts b/web/src/pages/agent/context.ts
index 6dc1cd476..ddb2d6ddb 100644
--- a/web/src/pages/agent/context.ts
+++ b/web/src/pages/agent/context.ts
@@ -1,6 +1,16 @@
import { RAGFlowNodeType } from '@/interfaces/database/flow';
import { createContext } from 'react';
+import { useAddNode } from './hooks/use-add-node';
export const AgentFormContext = createContext(
undefined,
);
+
+type AgentInstanceContextType = Pick<
+ ReturnType,
+ 'addCanvasNode'
+>;
+
+export const AgentInstanceContext = createContext(
+ {} as AgentInstanceContextType,
+);
diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx
index edffbd2a7..82eb2cd31 100644
--- a/web/src/pages/agent/form/agent-form/index.tsx
+++ b/web/src/pages/agent/form/agent-form/index.tsx
@@ -11,11 +11,12 @@ import {
FormLabel,
} from '@/components/ui/form';
import { zodResolver } from '@hookform/resolvers/zod';
-import { useMemo } from 'react';
+import { useContext, useMemo } from 'react';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
-import { initialAgentValues } from '../../constant';
+import { Operator, initialAgentValues } from '../../constant';
+import { AgentInstanceContext } from '../../context';
import { INextOperatorForm } from '../../interface';
import { Output } from '../components/output';
import { PromptEditor } from '../components/prompt-editor';
@@ -62,6 +63,8 @@ const AgentForm = ({ node }: INextOperatorForm) => {
useWatchFormChange(node?.id, form);
+ const { addCanvasNode } = useContext(AgentInstanceContext);
+
return (
diff --git a/web/src/pages/agent/hooks.tsx b/web/src/pages/agent/hooks.tsx
index 0257e19dd..a868cfa14 100644
--- a/web/src/pages/agent/hooks.tsx
+++ b/web/src/pages/agent/hooks.tsx
@@ -263,7 +263,7 @@ export const useHandleDrop = () => {
[reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode],
);
- return { onDrop, onDragOver, setReactFlowInstance };
+ return { onDrop, onDragOver, setReactFlowInstance, reactFlowInstance };
};
export const useHandleFormValuesChange = (
diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts
new file mode 100644
index 000000000..b8166c92b
--- /dev/null
+++ b/web/src/pages/agent/hooks/use-add-node.ts
@@ -0,0 +1,224 @@
+import { useFetchModelId } from '@/hooks/logic-hooks';
+import { Node, Position, ReactFlowInstance } from '@xyflow/react';
+import humanId from 'human-id';
+import { lowerFirst } from 'lodash';
+import { useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+import {
+ NodeMap,
+ Operator,
+ initialAgentValues,
+ initialAkShareValues,
+ initialArXivValues,
+ initialBaiduFanyiValues,
+ initialBaiduValues,
+ initialBeginValues,
+ initialBingValues,
+ initialCategorizeValues,
+ initialCodeValues,
+ initialConcentratorValues,
+ initialCrawlerValues,
+ initialDeepLValues,
+ initialDuckValues,
+ initialEmailValues,
+ initialExeSqlValues,
+ initialGenerateValues,
+ initialGithubValues,
+ initialGoogleScholarValues,
+ initialGoogleValues,
+ initialInvokeValues,
+ initialIterationValues,
+ initialJin10Values,
+ initialKeywordExtractValues,
+ initialMessageValues,
+ initialNoteValues,
+ initialPubMedValues,
+ initialQWeatherValues,
+ initialRelevantValues,
+ initialRetrievalValues,
+ initialRewriteQuestionValues,
+ initialSwitchValues,
+ initialTemplateValues,
+ initialTuShareValues,
+ initialWaitingDialogueValues,
+ initialWenCaiValues,
+ initialWikipediaValues,
+ initialYahooFinanceValues,
+} from '../constant';
+import useGraphStore from '../store';
+import {
+ generateNodeNamesWithIncreasingIndex,
+ getNodeDragHandle,
+ getRelativePositionToIterationNode,
+} from '../utils';
+
+export const useInitializeOperatorParams = () => {
+ const llmId = useFetchModelId();
+
+ const initialFormValuesMap = useMemo(() => {
+ return {
+ [Operator.Begin]: initialBeginValues,
+ [Operator.Retrieval]: initialRetrievalValues,
+ [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
+ [Operator.Answer]: {},
+ [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
+ [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
+ [Operator.RewriteQuestion]: {
+ ...initialRewriteQuestionValues,
+ llm_id: llmId,
+ },
+ [Operator.Message]: initialMessageValues,
+ [Operator.KeywordExtract]: {
+ ...initialKeywordExtractValues,
+ llm_id: llmId,
+ },
+ [Operator.DuckDuckGo]: initialDuckValues,
+ [Operator.Baidu]: initialBaiduValues,
+ [Operator.Wikipedia]: initialWikipediaValues,
+ [Operator.PubMed]: initialPubMedValues,
+ [Operator.ArXiv]: initialArXivValues,
+ [Operator.Google]: initialGoogleValues,
+ [Operator.Bing]: initialBingValues,
+ [Operator.GoogleScholar]: initialGoogleScholarValues,
+ [Operator.DeepL]: initialDeepLValues,
+ [Operator.GitHub]: initialGithubValues,
+ [Operator.BaiduFanyi]: initialBaiduFanyiValues,
+ [Operator.QWeather]: initialQWeatherValues,
+ [Operator.ExeSQL]: { ...initialExeSqlValues, llm_id: llmId },
+ [Operator.Switch]: initialSwitchValues,
+ [Operator.WenCai]: initialWenCaiValues,
+ [Operator.AkShare]: initialAkShareValues,
+ [Operator.YahooFinance]: initialYahooFinanceValues,
+ [Operator.Jin10]: initialJin10Values,
+ [Operator.Concentrator]: initialConcentratorValues,
+ [Operator.TuShare]: initialTuShareValues,
+ [Operator.Note]: initialNoteValues,
+ [Operator.Crawler]: initialCrawlerValues,
+ [Operator.Invoke]: initialInvokeValues,
+ [Operator.Template]: initialTemplateValues,
+ [Operator.Email]: initialEmailValues,
+ [Operator.Iteration]: initialIterationValues,
+ [Operator.IterationStart]: initialIterationValues,
+ [Operator.Code]: initialCodeValues,
+ [Operator.WaitingDialogue]: initialWaitingDialogueValues,
+ [Operator.Agent]: { ...initialAgentValues, llm_id: llmId },
+ };
+ }, [llmId]);
+
+ const initializeOperatorParams = useCallback(
+ (operatorName: Operator) => {
+ return initialFormValuesMap[operatorName];
+ },
+ [initialFormValuesMap],
+ );
+
+ return initializeOperatorParams;
+};
+
+export const useGetNodeName = () => {
+ const { t } = useTranslation();
+
+ return (type: string) => {
+ const name = t(`flow.${lowerFirst(type)}`);
+ return name;
+ };
+};
+
+export function useAddNode(reactFlowInstance?: ReactFlowInstance) {
+ const addNode = useGraphStore((state) => state.addNode);
+ const getNode = useGraphStore((state) => state.getNode);
+ const addEdge = useGraphStore((state) => state.addEdge);
+ const getNodeName = useGetNodeName();
+ const initializeOperatorParams = useInitializeOperatorParams();
+ const nodes = useGraphStore((state) => state.nodes);
+ // const [reactFlowInstance, setReactFlowInstance] =
+ // useState>();
+
+ const addCanvasNode = useCallback(
+ (type: string, id?: string) => (event: React.MouseEvent) => {
+ // reactFlowInstance.project was renamed to reactFlowInstance.screenToFlowPosition
+ // and you don't need to subtract the reactFlowBounds.left/top anymore
+ // details: https://@xyflow/react.dev/whats-new/2023-11-10
+ const position = reactFlowInstance?.screenToFlowPosition({
+ x: event.clientX,
+ y: event.clientY,
+ });
+
+ const newNode: Node = {
+ id: `${type}:${humanId()}`,
+ type: NodeMap[type as Operator] || 'ragNode',
+ position: position || {
+ x: 0,
+ y: 0,
+ },
+ data: {
+ label: `${type}`,
+ name: generateNodeNamesWithIncreasingIndex(getNodeName(type), nodes),
+ form: initializeOperatorParams(type as Operator),
+ },
+ sourcePosition: Position.Right,
+ targetPosition: Position.Left,
+ dragHandle: getNodeDragHandle(type),
+ };
+
+ if (type === Operator.Iteration) {
+ newNode.width = 500;
+ newNode.height = 250;
+ const iterationStartNode: Node = {
+ id: `${Operator.IterationStart}:${humanId()}`,
+ type: 'iterationStartNode',
+ position: { x: 50, y: 100 },
+ // draggable: false,
+ data: {
+ label: Operator.IterationStart,
+ name: Operator.IterationStart,
+ form: {},
+ },
+ parentId: newNode.id,
+ extent: 'parent',
+ };
+ addNode(newNode);
+ addNode(iterationStartNode);
+ } else if (type === Operator.Agent) {
+ const agentNode = getNode(id);
+ if (agentNode) {
+ newNode.position = {
+ x: agentNode.position.x + 82,
+ y: agentNode.position.y + 140,
+ };
+ }
+ addNode(newNode);
+ if (id) {
+ addEdge({
+ source: id,
+ target: newNode.id,
+ sourceHandle: 'e',
+ targetHandle: 'f',
+ });
+ }
+ } else {
+ const subNodeOfIteration = getRelativePositionToIterationNode(
+ nodes,
+ position,
+ );
+ if (subNodeOfIteration) {
+ newNode.parentId = subNodeOfIteration.parentId;
+ newNode.position = subNodeOfIteration.position;
+ newNode.extent = 'parent';
+ }
+ addNode(newNode);
+ }
+ },
+ [
+ addEdge,
+ addNode,
+ getNode,
+ getNodeName,
+ initializeOperatorParams,
+ nodes,
+ reactFlowInstance,
+ ],
+ );
+
+ return { addCanvasNode };
+}