diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index b10df254d..2b6f23926 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -9,3 +9,21 @@ export type ICategorizeItemResult = Record< string, Omit & { examples: string[] } >; + +export interface ISwitchCondition { + items: ISwitchItem[]; + logical_operator: string; + to: string[]; +} + +export interface ISwitchItem { + cpn_id: string; + operator: string; + value: string; +} + +export interface ISwitchForm { + conditions: ISwitchCondition[]; + end_cpn_ids: string[]; + no: string; +} diff --git a/web/src/interfaces/database/flow.ts b/web/src/interfaces/database/flow.ts index 4ad847d16..2d4aa4cbd 100644 --- a/web/src/interfaces/database/flow.ts +++ b/web/src/interfaces/database/flow.ts @@ -92,7 +92,7 @@ export interface IRelevantForm extends IGenerateForm { export interface ISwitchCondition { items: ISwitchItem[]; logical_operator: string; - to: string; + to: string[] | string; } export interface ISwitchItem { diff --git a/web/src/pages/agent/canvas/index.tsx b/web/src/pages/agent/canvas/index.tsx index 27fb0eb6d..21279da0b 100644 --- a/web/src/pages/agent/canvas/index.tsx +++ b/web/src/pages/agent/canvas/index.tsx @@ -17,7 +17,6 @@ import { useHandleDrop, useSelectCanvasData, useValidateConnection, - useWatchNodeFormDataChange, } from '../hooks'; import { useAddNode } from '../hooks/use-add-node'; import { useBeforeDelete } from '../hooks/use-before-delete'; @@ -120,8 +119,6 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) { const { handleBeforeDelete } = useBeforeDelete(); - useWatchNodeFormDataChange(); - const { addCanvasNode } = useAddNode(reactFlowInstance); useEffect(() => { diff --git a/web/src/pages/agent/canvas/node/switch-node.tsx b/web/src/pages/agent/canvas/node/switch-node.tsx index 7b82574f6..97b4ce79f 100644 --- a/web/src/pages/agent/canvas/node/switch-node.tsx +++ b/web/src/pages/agent/canvas/node/switch-node.tsx @@ -1,9 +1,12 @@ +import { IconFont } from '@/components/icon-font'; import { useTheme } from '@/components/theme-provider'; +import { Card, CardContent } from '@/components/ui/card'; import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow'; import { Handle, NodeProps, Position } from '@xyflow/react'; -import { Divider, Flex } from 'antd'; +import { Flex } from 'antd'; import classNames from 'classnames'; -import { memo } from 'react'; +import { memo, useCallback } from 'react'; +import { SwitchOperatorOptions } from '../../constant'; import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query'; import { RightHandleStyle } from './handle-icon'; import { useBuildSwitchHandlePositions } from './hooks'; @@ -29,29 +32,28 @@ const ConditionBlock = ({ }) => { const items = condition?.items ?? []; const getLabel = useGetComponentLabelByValue(nodeId); + + const renderOperatorIcon = useCallback((operator?: string) => { + const name = SwitchOperatorOptions.find((x) => x.value === operator)?.icon; + return ; + }, []); + return ( - - {items.map((x, idx) => ( -
- -
- {getLabel(x?.cpn_id)} -
- {x?.operator} - - {x?.value} - -
- {idx + 1 < items.length && ( - - {condition?.logical_operator} - - )} -
- ))} -
+ + + {items.map((x, idx) => ( +
+
+
+ {getLabel(x?.cpn_id)} +
+ {renderOperatorIcon(x?.operator)} +
{x?.value}
+
+
+ ))} +
+
); }; @@ -87,7 +89,10 @@ function InnerSwitchNode({ id, data, selected }: NodeProps) {
- {idx < positions.length - 1 && position.text} + + {idx < positions.length - 1 && + position.condition?.logical_operator?.toUpperCase()} + {getConditionKey(idx, positions.length)} {position.condition && ( diff --git a/web/src/pages/agent/constant.tsx b/web/src/pages/agent/constant.tsx index 4f02a8b36..e0709bf73 100644 --- a/web/src/pages/agent/constant.tsx +++ b/web/src/pages/agent/constant.tsx @@ -129,6 +129,8 @@ export enum Operator { Agent = 'Agent', } +export const SwitchLogicOperatorOptions = ['and', 'or']; + export const CommonOperatorList = Object.values(Operator).filter( (x) => x !== Operator.Note, ); @@ -445,6 +447,23 @@ export const componentMenuList = [ }, ]; +export const SwitchOperatorOptions = [ + { value: '=', label: 'equal', icon: 'equal' }, + { value: '≠', label: 'notEqual', icon: 'not-equals' }, + { value: '>', label: 'gt', icon: 'Less' }, + { value: '≥', label: 'ge', icon: 'Greater-or-equal' }, + { value: '<', label: 'lt', icon: 'Less' }, + { value: '≤', label: 'le', icon: 'less-or-equal' }, + { value: 'contains', label: 'contains', icon: 'Contains' }, + { value: 'not contains', label: 'notContains', icon: 'not-contains' }, + { value: 'start with', label: 'startWith', icon: 'list-start' }, + { value: 'end with', label: 'endWith', icon: 'list-end' }, + { value: 'empty', label: 'empty', icon: 'circle' }, + { value: 'not empty', label: 'notEmpty', icon: 'circle-slash-2' }, +]; + +export const SwitchElseTo = 'end_cpn_ids'; + const initialQueryBaseValues = { query: [], }; @@ -616,7 +635,20 @@ export const initialExeSqlValues = { ...initialQueryBaseValues, }; -export const initialSwitchValues = { conditions: [] }; +export const initialSwitchValues = { + conditions: [ + { + logical_operator: SwitchLogicOperatorOptions[0], + items: [ + { + operator: SwitchOperatorOptions[0].value, + }, + ], + to: [], + }, + ], + [SwitchElseTo]: [], +}; export const initialWenCaiValues = { top_n: 20, @@ -3000,25 +3032,6 @@ export const ExeSQLOptions = ['mysql', 'postgresql', 'mariadb', 'mssql'].map( }), ); -export const SwitchElseTo = 'end_cpn_id'; - -export const SwitchOperatorOptions = [ - { value: '=', label: 'equal', icon: 'equal' }, - { value: '≠', label: 'notEqual', icon: 'not-equals' }, - { value: '>', label: 'gt', icon: 'Less' }, - { value: '≥', label: 'ge', icon: 'Greater-or-equal' }, - { value: '<', label: 'lt', icon: 'Less' }, - { value: '≤', label: 'le', icon: 'less-or-equal' }, - { value: 'contains', label: 'contains', icon: 'Contains' }, - { value: 'not contains', label: 'notContains', icon: 'not-contains' }, - { value: 'start with', label: 'startWith', icon: 'list-start' }, - { value: 'end with', label: 'endWith', icon: 'list-end' }, - // { value: 'empty', label: 'empty', icon: '' }, - // { value: 'not empty', label: 'notEmpty', icon: '' }, -]; - -export const SwitchLogicOperatorOptions = ['and', 'or']; - export const WenCaiQueryTypeOptions = [ 'stock', 'zhishu', diff --git a/web/src/pages/agent/form/switch-form/index.tsx b/web/src/pages/agent/form/switch-form/index.tsx index 6d4394c25..ef300b877 100644 --- a/web/src/pages/agent/form/switch-form/index.tsx +++ b/web/src/pages/agent/form/switch-form/index.tsx @@ -12,7 +12,7 @@ import { import { RAGFlowSelect } from '@/components/ui/select'; import { Separator } from '@/components/ui/separator'; import { Textarea } from '@/components/ui/textarea'; -import { ISwitchForm } from '@/interfaces/database/flow'; +import { ISwitchForm } from '@/interfaces/database/agent'; import { cn } from '@/lib/utils'; import { zodResolver } from '@hookform/resolvers/zod'; import { X } from 'lucide-react'; @@ -27,6 +27,7 @@ import { } from '../../constant'; import { useBuildFormSelectOptions } from '../../form-hooks'; import { useBuildComponentIdAndBeginOptions } from '../../hooks/use-get-begin-query'; +import { useWatchFormChange } from '../../hooks/use-watch-form-change'; import { IOperatorForm } from '../../interface'; import { useValues } from './use-values'; @@ -40,20 +41,27 @@ type ConditionCardsProps = { parentLength: number; } & IOperatorForm; +const OperatorIcon = function OperatorIcon({ + icon, + value, +}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) { + return ( + ', + })} + > + ); +}; + function useBuildSwitchOperatorOptions() { const { t } = useTranslation(); const switchOperatorOptions = useMemo(() => { return SwitchOperatorOptions.map((x) => ({ value: x.value, - icon: ( - ', - })} - > - ), + icon: , label: t(`flow.switchOperatorOptions.${x.label}`), })); }, [t]); @@ -174,7 +182,7 @@ function ConditionCards({ className="mt-6" onClick={() => append({ operator: switchOperatorOptions[0].value })} > - add + Add
@@ -183,7 +191,7 @@ function ConditionCards({ const SwitchForm = ({ node }: IOperatorForm) => { const { t } = useTranslation(); - const values = useValues(); + const values = useValues(node); const switchOperatorOptions = useBuildSwitchOperatorOptions(); const FormSchema = z.object({ @@ -234,6 +242,8 @@ const SwitchForm = ({ node }: IOperatorForm) => { })); }, [t]); + useWatchFormChange(node?.id, form); + return (
{ }) } > - add + Add
diff --git a/web/src/pages/agent/form/switch-form/use-values.ts b/web/src/pages/agent/form/switch-form/use-values.ts index 3323c005a..990bd0124 100644 --- a/web/src/pages/agent/form/switch-form/use-values.ts +++ b/web/src/pages/agent/form/switch-form/use-values.ts @@ -1,16 +1,13 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { isEmpty } from 'lodash'; import { useMemo } from 'react'; - -const defaultValues = { - conditions: [], -}; +import { initialSwitchValues } from '../../constant'; export function useValues(node?: RAGFlowNodeType) { const values = useMemo(() => { const formData = node?.data?.form; if (isEmpty(formData)) { - return defaultValues; + return initialSwitchValues; } return formData; diff --git a/web/src/pages/agent/hooks.tsx b/web/src/pages/agent/hooks.tsx index a868cfa14..7d194fef0 100644 --- a/web/src/pages/agent/hooks.tsx +++ b/web/src/pages/agent/hooks.tsx @@ -15,10 +15,10 @@ import React, { // import { shallow } from 'zustand/shallow'; import { settledModelVariableMap } from '@/constants/knowledge'; import { useFetchModelId } from '@/hooks/logic-hooks'; +import { ISwitchForm } from '@/interfaces/database/agent'; import { ICategorizeForm, IRelevantForm, - ISwitchForm, RAGFlowNodeType, } from '@/interfaces/database/flow'; import { message } from 'antd'; @@ -543,9 +543,9 @@ export const useWatchNodeFormDataChange = () => { case Operator.Categorize: buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm); break; - case Operator.Switch: - buildSwitchEdgesByFormData(node.id, form as ISwitchForm); - break; + // case Operator.Switch: + // buildSwitchEdgesByFormData(node.id, form as ISwitchForm); + // break; default: break; } @@ -555,7 +555,6 @@ export const useWatchNodeFormDataChange = () => { buildCategorizeEdgesByFormData, getNode, buildRelevantEdgesByFormData, - buildSwitchEdgesByFormData, ]); }; diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts index d2b8b7fba..b234805d8 100644 --- a/web/src/pages/agent/hooks/use-add-node.ts +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -224,6 +224,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { [ addEdge, addNode, + edges, getNode, getNodeName, initializeOperatorParams, diff --git a/web/src/pages/agent/hooks/use-get-begin-query.tsx b/web/src/pages/agent/hooks/use-get-begin-query.tsx index 2e105b092..c53e23f48 100644 --- a/web/src/pages/agent/hooks/use-get-begin-query.tsx +++ b/web/src/pages/agent/hooks/use-get-begin-query.tsx @@ -135,24 +135,6 @@ export const useBuildVariableOptions = (nodeId?: string) => { return options; }; -export const useGetComponentLabelByValue = (nodeId: string) => { - const options = useBuildVariableOptions(nodeId); - - const flattenOptions = useMemo(() => { - return options.reduce((pre, cur) => { - return [...pre, ...cur.options]; - }, []); - }, [options]); - - const getLabel = useCallback( - (val?: string) => { - return flattenOptions.find((x) => x.value === val)?.label; - }, - [flattenOptions], - ); - return getLabel; -}; - export function useBuildQueryVariableOptions() { const { data } = useFetchAgent(); const node = useContext(AgentFormContext); @@ -220,3 +202,21 @@ export function useBuildComponentIdAndBeginOptions( return [...beginOptions, ...componentIdOptions]; } + +export const useGetComponentLabelByValue = (nodeId: string) => { + const options = useBuildComponentIdAndBeginOptions(nodeId); + + const flattenOptions = useMemo(() => { + return options.reduce((pre, cur) => { + return [...pre, ...cur.options]; + }, []); + }, [options]); + + const getLabel = useCallback( + (val?: string) => { + return flattenOptions.find((x) => x.value === val)?.label; + }, + [flattenOptions], + ); + return getLabel; +}; diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index c6e8c69b6..8b6d3842e 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -56,6 +56,7 @@ export type RFState = { source: string, sourceHandle?: string | null, target?: string | null, + isConnecting?: boolean, ) => void; deletePreviousEdgeOfClassificationNode: (connection: Connection) => void; duplicateNode: (id: string, name: string) => void; @@ -204,7 +205,7 @@ const useGraphStore = create()( ]); break; case Operator.Switch: { - updateSwitchFormData(source, sourceHandle, target); + updateSwitchFormData(source, sourceHandle, target, true); break; } default: @@ -219,7 +220,7 @@ const useGraphStore = create()( const anchoredNodes = [ Operator.Categorize, Operator.Relevant, - Operator.Switch, + // Operator.Switch, ]; if ( anchoredNodes.some( @@ -303,7 +304,7 @@ const useGraphStore = create()( const currentEdge = edges.find((x) => x.id === id); if (currentEdge) { - const { source, sourceHandle } = currentEdge; + const { source, sourceHandle, target } = currentEdge; const operatorType = getOperatorTypeFromId(source); // After deleting the edge, set the corresponding field in the node's form field to undefined switch (operatorType) { @@ -321,7 +322,7 @@ const useGraphStore = create()( ]); break; case Operator.Switch: { - updateSwitchFormData(source, sourceHandle, undefined); + updateSwitchFormData(source, sourceHandle, target, false); break; } default: @@ -402,15 +403,32 @@ const useGraphStore = create()( return nextNodes; }, - updateSwitchFormData: (source, sourceHandle, target) => { - const { updateNodeForm } = get(); + updateSwitchFormData: (source, sourceHandle, target, isConnecting) => { + const { updateNodeForm, edges } = get(); if (sourceHandle) { + // A handle will connect to multiple downstream nodes + let currentHandleTargets = edges + .filter( + (x) => + x.source === source && + x.sourceHandle === sourceHandle && + typeof x.target === 'string', + ) + .map((x) => x.target); + + let targets: string[] = currentHandleTargets; + if (target) { + if (!isConnecting) { + targets = currentHandleTargets.filter((x) => x !== target); + } + } + if (sourceHandle === SwitchElseTo) { - updateNodeForm(source, target, [SwitchElseTo]); + updateNodeForm(source, targets, [SwitchElseTo]); } else { const operatorIndex = getOperatorIndex(sourceHandle); if (operatorIndex) { - updateNodeForm(source, target, [ + updateNodeForm(source, targets, [ 'conditions', Number(operatorIndex) - 1, // The index is the conditions form index 'to', @@ -448,7 +466,7 @@ const useGraphStore = create()( return generateNodeNamesWithIncreasingIndex(name, nodes); }, })), - { name: 'graph' }, + { name: 'graph', trace: true }, ), );