Feat: Connect conditional operators to other operators #3221 (#8231)

### What problem does this PR solve?

Feat: Connect conditional operators to other operators #3221

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
balibabu
2025-06-13 09:30:34 +08:00
committed by GitHub
parent 99725444f1
commit a9d9215547
11 changed files with 156 additions and 98 deletions

View File

@ -9,3 +9,21 @@ export type ICategorizeItemResult = Record<
string, string,
Omit<ICategorizeItem, 'name' | 'examples'> & { examples: string[] } Omit<ICategorizeItem, 'name' | 'examples'> & { examples: string[] }
>; >;
export interface ISwitchCondition {
items: ISwitchItem[];
logical_operator: string;
to: string[];
}
export interface ISwitchItem {
cpn_id: string;
operator: string;
value: string;
}
export interface ISwitchForm {
conditions: ISwitchCondition[];
end_cpn_ids: string[];
no: string;
}

View File

@ -92,7 +92,7 @@ export interface IRelevantForm extends IGenerateForm {
export interface ISwitchCondition { export interface ISwitchCondition {
items: ISwitchItem[]; items: ISwitchItem[];
logical_operator: string; logical_operator: string;
to: string; to: string[] | string;
} }
export interface ISwitchItem { export interface ISwitchItem {

View File

@ -17,7 +17,6 @@ import {
useHandleDrop, useHandleDrop,
useSelectCanvasData, useSelectCanvasData,
useValidateConnection, useValidateConnection,
useWatchNodeFormDataChange,
} from '../hooks'; } from '../hooks';
import { useAddNode } from '../hooks/use-add-node'; import { useAddNode } from '../hooks/use-add-node';
import { useBeforeDelete } from '../hooks/use-before-delete'; import { useBeforeDelete } from '../hooks/use-before-delete';
@ -120,8 +119,6 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
const { handleBeforeDelete } = useBeforeDelete(); const { handleBeforeDelete } = useBeforeDelete();
useWatchNodeFormDataChange();
const { addCanvasNode } = useAddNode(reactFlowInstance); const { addCanvasNode } = useAddNode(reactFlowInstance);
useEffect(() => { useEffect(() => {

View File

@ -1,9 +1,12 @@
import { IconFont } from '@/components/icon-font';
import { useTheme } from '@/components/theme-provider'; import { useTheme } from '@/components/theme-provider';
import { Card, CardContent } from '@/components/ui/card';
import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow'; import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
import { Handle, NodeProps, Position } from '@xyflow/react'; import { Handle, NodeProps, Position } from '@xyflow/react';
import { Divider, Flex } from 'antd'; import { Flex } from 'antd';
import classNames from 'classnames'; import classNames from 'classnames';
import { memo } from 'react'; import { memo, useCallback } from 'react';
import { SwitchOperatorOptions } from '../../constant';
import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query'; import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query';
import { RightHandleStyle } from './handle-icon'; import { RightHandleStyle } from './handle-icon';
import { useBuildSwitchHandlePositions } from './hooks'; import { useBuildSwitchHandlePositions } from './hooks';
@ -29,29 +32,28 @@ const ConditionBlock = ({
}) => { }) => {
const items = condition?.items ?? []; const items = condition?.items ?? [];
const getLabel = useGetComponentLabelByValue(nodeId); const getLabel = useGetComponentLabelByValue(nodeId);
const renderOperatorIcon = useCallback((operator?: string) => {
const name = SwitchOperatorOptions.find((x) => x.value === operator)?.icon;
return <IconFont name={name!}></IconFont>;
}, []);
return ( return (
<Flex vertical className={styles.conditionBlock}> <Card>
<CardContent className="space-y-1 p-1">
{items.map((x, idx) => ( {items.map((x, idx) => (
<div key={idx}> <div key={idx}>
<Flex> <section className="flex justify-between gap-2 items-center text-xs">
<div <div className="flex-1 truncate text-background-checked">
className={classNames(styles.conditionLine, styles.conditionKey)}
>
{getLabel(x?.cpn_id)} {getLabel(x?.cpn_id)}
</div> </div>
<span className={styles.conditionOperator}>{x?.operator}</span> <span>{renderOperatorIcon(x?.operator)}</span>
<Flex flex={1} className={styles.conditionLine}> <div className="flex-1 truncate">{x?.value}</div>
{x?.value} </section>
</Flex>
</Flex>
{idx + 1 < items.length && (
<Divider orientationMargin="0" className={styles.zeroDivider}>
{condition?.logical_operator}
</Divider>
)}
</div> </div>
))} ))}
</Flex> </CardContent>
</Card>
); );
}; };
@ -87,7 +89,10 @@ function InnerSwitchNode({ id, data, selected }: NodeProps<ISwitchNode>) {
<div key={idx}> <div key={idx}>
<Flex vertical> <Flex vertical>
<Flex justify={'space-between'}> <Flex justify={'space-between'}>
<span>{idx < positions.length - 1 && position.text}</span> <span className="text-text-sub-title text-xs translate-y-2">
{idx < positions.length - 1 &&
position.condition?.logical_operator?.toUpperCase()}
</span>
<span>{getConditionKey(idx, positions.length)}</span> <span>{getConditionKey(idx, positions.length)}</span>
</Flex> </Flex>
{position.condition && ( {position.condition && (

View File

@ -129,6 +129,8 @@ export enum Operator {
Agent = 'Agent', Agent = 'Agent',
} }
export const SwitchLogicOperatorOptions = ['and', 'or'];
export const CommonOperatorList = Object.values(Operator).filter( export const CommonOperatorList = Object.values(Operator).filter(
(x) => x !== Operator.Note, (x) => x !== Operator.Note,
); );
@ -445,6 +447,23 @@ export const componentMenuList = [
}, },
]; ];
export const SwitchOperatorOptions = [
{ value: '=', label: 'equal', icon: 'equal' },
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
{ value: '>', label: 'gt', icon: 'Less' },
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
{ value: '<', label: 'lt', icon: 'Less' },
{ value: '≤', label: 'le', icon: 'less-or-equal' },
{ value: 'contains', label: 'contains', icon: 'Contains' },
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
{ value: 'start with', label: 'startWith', icon: 'list-start' },
{ value: 'end with', label: 'endWith', icon: 'list-end' },
{ value: 'empty', label: 'empty', icon: 'circle' },
{ value: 'not empty', label: 'notEmpty', icon: 'circle-slash-2' },
];
export const SwitchElseTo = 'end_cpn_ids';
const initialQueryBaseValues = { const initialQueryBaseValues = {
query: [], query: [],
}; };
@ -616,7 +635,20 @@ export const initialExeSqlValues = {
...initialQueryBaseValues, ...initialQueryBaseValues,
}; };
export const initialSwitchValues = { conditions: [] }; export const initialSwitchValues = {
conditions: [
{
logical_operator: SwitchLogicOperatorOptions[0],
items: [
{
operator: SwitchOperatorOptions[0].value,
},
],
to: [],
},
],
[SwitchElseTo]: [],
};
export const initialWenCaiValues = { export const initialWenCaiValues = {
top_n: 20, top_n: 20,
@ -3000,25 +3032,6 @@ export const ExeSQLOptions = ['mysql', 'postgresql', 'mariadb', 'mssql'].map(
}), }),
); );
export const SwitchElseTo = 'end_cpn_id';
export const SwitchOperatorOptions = [
{ value: '=', label: 'equal', icon: 'equal' },
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
{ value: '>', label: 'gt', icon: 'Less' },
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
{ value: '<', label: 'lt', icon: 'Less' },
{ value: '≤', label: 'le', icon: 'less-or-equal' },
{ value: 'contains', label: 'contains', icon: 'Contains' },
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
{ value: 'start with', label: 'startWith', icon: 'list-start' },
{ value: 'end with', label: 'endWith', icon: 'list-end' },
// { value: 'empty', label: 'empty', icon: '' },
// { value: 'not empty', label: 'notEmpty', icon: '' },
];
export const SwitchLogicOperatorOptions = ['and', 'or'];
export const WenCaiQueryTypeOptions = [ export const WenCaiQueryTypeOptions = [
'stock', 'stock',
'zhishu', 'zhishu',

View File

@ -12,7 +12,7 @@ import {
import { RAGFlowSelect } from '@/components/ui/select'; import { RAGFlowSelect } from '@/components/ui/select';
import { Separator } from '@/components/ui/separator'; import { Separator } from '@/components/ui/separator';
import { Textarea } from '@/components/ui/textarea'; import { Textarea } from '@/components/ui/textarea';
import { ISwitchForm } from '@/interfaces/database/flow'; import { ISwitchForm } from '@/interfaces/database/agent';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { zodResolver } from '@hookform/resolvers/zod'; import { zodResolver } from '@hookform/resolvers/zod';
import { X } from 'lucide-react'; import { X } from 'lucide-react';
@ -27,6 +27,7 @@ import {
} from '../../constant'; } from '../../constant';
import { useBuildFormSelectOptions } from '../../form-hooks'; import { useBuildFormSelectOptions } from '../../form-hooks';
import { useBuildComponentIdAndBeginOptions } from '../../hooks/use-get-begin-query'; import { useBuildComponentIdAndBeginOptions } from '../../hooks/use-get-begin-query';
import { useWatchFormChange } from '../../hooks/use-watch-form-change';
import { IOperatorForm } from '../../interface'; import { IOperatorForm } from '../../interface';
import { useValues } from './use-values'; import { useValues } from './use-values';
@ -40,20 +41,27 @@ type ConditionCardsProps = {
parentLength: number; parentLength: number;
} & IOperatorForm; } & IOperatorForm;
const OperatorIcon = function OperatorIcon({
icon,
value,
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
return (
<IconFont
name={icon}
className={cn('size-4', {
'rotate-180': value === '>',
})}
></IconFont>
);
};
function useBuildSwitchOperatorOptions() { function useBuildSwitchOperatorOptions() {
const { t } = useTranslation(); const { t } = useTranslation();
const switchOperatorOptions = useMemo(() => { const switchOperatorOptions = useMemo(() => {
return SwitchOperatorOptions.map((x) => ({ return SwitchOperatorOptions.map((x) => ({
value: x.value, value: x.value,
icon: ( icon: <OperatorIcon icon={x.icon} value={x.value}></OperatorIcon>,
<IconFont
name={x.icon}
className={cn('size-4', {
'rotate-180': x.value === '>',
})}
></IconFont>
),
label: t(`flow.switchOperatorOptions.${x.label}`), label: t(`flow.switchOperatorOptions.${x.label}`),
})); }));
}, [t]); }, [t]);
@ -174,7 +182,7 @@ function ConditionCards({
className="mt-6" className="mt-6"
onClick={() => append({ operator: switchOperatorOptions[0].value })} onClick={() => append({ operator: switchOperatorOptions[0].value })}
> >
add Add
</BlockButton> </BlockButton>
</div> </div>
</section> </section>
@ -183,7 +191,7 @@ function ConditionCards({
const SwitchForm = ({ node }: IOperatorForm) => { const SwitchForm = ({ node }: IOperatorForm) => {
const { t } = useTranslation(); const { t } = useTranslation();
const values = useValues(); const values = useValues(node);
const switchOperatorOptions = useBuildSwitchOperatorOptions(); const switchOperatorOptions = useBuildSwitchOperatorOptions();
const FormSchema = z.object({ const FormSchema = z.object({
@ -234,6 +242,8 @@ const SwitchForm = ({ node }: IOperatorForm) => {
})); }));
}, [t]); }, [t]);
useWatchFormChange(node?.id, form);
return ( return (
<Form {...form}> <Form {...form}>
<form <form
@ -289,7 +299,7 @@ const SwitchForm = ({ node }: IOperatorForm) => {
}) })
} }
> >
add Add
</BlockButton> </BlockButton>
</form> </form>
</Form> </Form>

View File

@ -1,16 +1,13 @@
import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { RAGFlowNodeType } from '@/interfaces/database/flow';
import { isEmpty } from 'lodash'; import { isEmpty } from 'lodash';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { initialSwitchValues } from '../../constant';
const defaultValues = {
conditions: [],
};
export function useValues(node?: RAGFlowNodeType) { export function useValues(node?: RAGFlowNodeType) {
const values = useMemo(() => { const values = useMemo(() => {
const formData = node?.data?.form; const formData = node?.data?.form;
if (isEmpty(formData)) { if (isEmpty(formData)) {
return defaultValues; return initialSwitchValues;
} }
return formData; return formData;

View File

@ -15,10 +15,10 @@ import React, {
// import { shallow } from 'zustand/shallow'; // import { shallow } from 'zustand/shallow';
import { settledModelVariableMap } from '@/constants/knowledge'; import { settledModelVariableMap } from '@/constants/knowledge';
import { useFetchModelId } from '@/hooks/logic-hooks'; import { useFetchModelId } from '@/hooks/logic-hooks';
import { ISwitchForm } from '@/interfaces/database/agent';
import { import {
ICategorizeForm, ICategorizeForm,
IRelevantForm, IRelevantForm,
ISwitchForm,
RAGFlowNodeType, RAGFlowNodeType,
} from '@/interfaces/database/flow'; } from '@/interfaces/database/flow';
import { message } from 'antd'; import { message } from 'antd';
@ -543,9 +543,9 @@ export const useWatchNodeFormDataChange = () => {
case Operator.Categorize: case Operator.Categorize:
buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm); buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
break; break;
case Operator.Switch: // case Operator.Switch:
buildSwitchEdgesByFormData(node.id, form as ISwitchForm); // buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
break; // break;
default: default:
break; break;
} }
@ -555,7 +555,6 @@ export const useWatchNodeFormDataChange = () => {
buildCategorizeEdgesByFormData, buildCategorizeEdgesByFormData,
getNode, getNode,
buildRelevantEdgesByFormData, buildRelevantEdgesByFormData,
buildSwitchEdgesByFormData,
]); ]);
}; };

View File

@ -224,6 +224,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
[ [
addEdge, addEdge,
addNode, addNode,
edges,
getNode, getNode,
getNodeName, getNodeName,
initializeOperatorParams, initializeOperatorParams,

View File

@ -135,24 +135,6 @@ export const useBuildVariableOptions = (nodeId?: string) => {
return options; return options;
}; };
export const useGetComponentLabelByValue = (nodeId: string) => {
const options = useBuildVariableOptions(nodeId);
const flattenOptions = useMemo(() => {
return options.reduce<DefaultOptionType[]>((pre, cur) => {
return [...pre, ...cur.options];
}, []);
}, [options]);
const getLabel = useCallback(
(val?: string) => {
return flattenOptions.find((x) => x.value === val)?.label;
},
[flattenOptions],
);
return getLabel;
};
export function useBuildQueryVariableOptions() { export function useBuildQueryVariableOptions() {
const { data } = useFetchAgent(); const { data } = useFetchAgent();
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
@ -220,3 +202,21 @@ export function useBuildComponentIdAndBeginOptions(
return [...beginOptions, ...componentIdOptions]; return [...beginOptions, ...componentIdOptions];
} }
export const useGetComponentLabelByValue = (nodeId: string) => {
const options = useBuildComponentIdAndBeginOptions(nodeId);
const flattenOptions = useMemo(() => {
return options.reduce<DefaultOptionType[]>((pre, cur) => {
return [...pre, ...cur.options];
}, []);
}, [options]);
const getLabel = useCallback(
(val?: string) => {
return flattenOptions.find((x) => x.value === val)?.label;
},
[flattenOptions],
);
return getLabel;
};

View File

@ -56,6 +56,7 @@ export type RFState = {
source: string, source: string,
sourceHandle?: string | null, sourceHandle?: string | null,
target?: string | null, target?: string | null,
isConnecting?: boolean,
) => void; ) => void;
deletePreviousEdgeOfClassificationNode: (connection: Connection) => void; deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
duplicateNode: (id: string, name: string) => void; duplicateNode: (id: string, name: string) => void;
@ -204,7 +205,7 @@ const useGraphStore = create<RFState>()(
]); ]);
break; break;
case Operator.Switch: { case Operator.Switch: {
updateSwitchFormData(source, sourceHandle, target); updateSwitchFormData(source, sourceHandle, target, true);
break; break;
} }
default: default:
@ -219,7 +220,7 @@ const useGraphStore = create<RFState>()(
const anchoredNodes = [ const anchoredNodes = [
Operator.Categorize, Operator.Categorize,
Operator.Relevant, Operator.Relevant,
Operator.Switch, // Operator.Switch,
]; ];
if ( if (
anchoredNodes.some( anchoredNodes.some(
@ -303,7 +304,7 @@ const useGraphStore = create<RFState>()(
const currentEdge = edges.find((x) => x.id === id); const currentEdge = edges.find((x) => x.id === id);
if (currentEdge) { if (currentEdge) {
const { source, sourceHandle } = currentEdge; const { source, sourceHandle, target } = currentEdge;
const operatorType = getOperatorTypeFromId(source); const operatorType = getOperatorTypeFromId(source);
// After deleting the edge, set the corresponding field in the node's form field to undefined // After deleting the edge, set the corresponding field in the node's form field to undefined
switch (operatorType) { switch (operatorType) {
@ -321,7 +322,7 @@ const useGraphStore = create<RFState>()(
]); ]);
break; break;
case Operator.Switch: { case Operator.Switch: {
updateSwitchFormData(source, sourceHandle, undefined); updateSwitchFormData(source, sourceHandle, target, false);
break; break;
} }
default: default:
@ -402,15 +403,32 @@ const useGraphStore = create<RFState>()(
return nextNodes; return nextNodes;
}, },
updateSwitchFormData: (source, sourceHandle, target) => { updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
const { updateNodeForm } = get(); const { updateNodeForm, edges } = get();
if (sourceHandle) { if (sourceHandle) {
// A handle will connect to multiple downstream nodes
let currentHandleTargets = edges
.filter(
(x) =>
x.source === source &&
x.sourceHandle === sourceHandle &&
typeof x.target === 'string',
)
.map((x) => x.target);
let targets: string[] = currentHandleTargets;
if (target) {
if (!isConnecting) {
targets = currentHandleTargets.filter((x) => x !== target);
}
}
if (sourceHandle === SwitchElseTo) { if (sourceHandle === SwitchElseTo) {
updateNodeForm(source, target, [SwitchElseTo]); updateNodeForm(source, targets, [SwitchElseTo]);
} else { } else {
const operatorIndex = getOperatorIndex(sourceHandle); const operatorIndex = getOperatorIndex(sourceHandle);
if (operatorIndex) { if (operatorIndex) {
updateNodeForm(source, target, [ updateNodeForm(source, targets, [
'conditions', 'conditions',
Number(operatorIndex) - 1, // The index is the conditions form index Number(operatorIndex) - 1, // The index is the conditions form index
'to', 'to',
@ -448,7 +466,7 @@ const useGraphStore = create<RFState>()(
return generateNodeNamesWithIncreasingIndex(name, nodes); return generateNodeNamesWithIncreasingIndex(name, nodes);
}, },
})), })),
{ name: 'graph' }, { name: 'graph', trace: true },
), ),
); );