Feat: Add the iteration Node #4242 (#4247)

### What problem does this PR solve?

Feat: Add the iteration Node #4242

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
balibabu
2024-12-27 11:24:17 +08:00
committed by GitHub
parent a6f4153775
commit a1a825c830
72 changed files with 1330 additions and 560 deletions

View File

@ -1,7 +1,3 @@
import { useSetModalState } from '@/hooks/common-hooks';
import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
import { IGraph } from '@/interfaces/database/flow';
import { useIsFetching } from '@tanstack/react-query';
import React, {
ChangeEvent,
useCallback,
@ -12,23 +8,17 @@ import React, {
import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
// import { shallow } from 'zustand/shallow';
import { variableEnabledFieldMap } from '@/constants/chat';
import { FileMimeType } from '@/constants/common';
import {
ModelVariableType,
settledModelVariableMap,
} from '@/constants/knowledge';
import { useFetchModelId } from '@/hooks/logic-hooks';
import { Variable } from '@/interfaces/database/chat';
import { downloadJsonFile } from '@/utils/file-util';
import { useDebounceEffect } from 'ahooks';
import { FormInstance, UploadFile, message } from 'antd';
import { DefaultOptionType } from 'antd/es/select';
import dayjs from 'dayjs';
import { FormInstance, message } from 'antd';
import { humanId } from 'human-id';
import { get, isEmpty, lowerFirst, pick } from 'lodash';
import trim from 'lodash/trim';
import { useTranslation } from 'react-i18next';
import { useParams } from 'umi';
import { v4 as uuid } from 'uuid';
import {
NodeMap,
@ -53,6 +43,7 @@ import {
initialGoogleScholarValues,
initialGoogleValues,
initialInvokeValues,
initialIterationValues,
initialJin10Values,
initialKeywordExtractValues,
initialMessageValues,
@ -69,18 +60,13 @@ import {
initialWikipediaValues,
initialYahooFinanceValues,
} from './constant';
import {
BeginQuery,
ICategorizeForm,
IRelevantForm,
ISwitchForm,
} from './interface';
import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface';
import useGraphStore, { RFState } from './store';
import {
buildDslComponentsByGraph,
generateNodeNamesWithIncreasingIndex,
generateSwitchHandleText,
getNodeDragHandle,
getRelativePositionToIterationNode,
replaceIdWithText,
} from './utils';
@ -145,6 +131,8 @@ export const useInitializeOperatorParams = () => {
[Operator.Invoke]: initialInvokeValues,
[Operator.Template]: initialTemplateValues,
[Operator.Email]: initialEmailValues,
[Operator.Iteration]: initialIterationValues,
[Operator.IterationStart]: initialIterationValues,
};
}, [llmId]);
@ -210,7 +198,7 @@ export const useHandleDrop = () => {
x: event.clientX,
y: event.clientY,
});
const newNode = {
const newNode: Node<any> = {
id: `${type}:${humanId()}`,
type: NodeMap[type as Operator] || 'ragNode',
position: position || {
@ -227,7 +215,38 @@ export const useHandleDrop = () => {
dragHandle: getNodeDragHandle(type),
};
addNode(newNode);
if (type === Operator.Iteration) {
newNode.style = {
width: 500,
height: 250,
};
const iterationStartNode: Node<any> = {
id: `${Operator.IterationStart}:${humanId()}`,
type: 'iterationStartNode',
position: { x: 50, y: 100 },
// draggable: false,
data: {
label: Operator.IterationStart,
name: Operator.IterationStart,
form: {},
},
parentId: newNode.id,
extent: 'parent',
};
addNode(newNode);
addNode(iterationStartNode);
} else {
const subNodeOfIteration = getRelativePositionToIterationNode(
nodes,
position,
);
if (subNodeOfIteration) {
newNode.parentId = subNodeOfIteration.parentId;
newNode.position = subNodeOfIteration.position;
newNode.extent = 'parent';
}
addNode(newNode);
}
},
[reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode],
);
@ -235,78 +254,6 @@ export const useHandleDrop = () => {
return { onDrop, onDragOver, setReactFlowInstance };
};
export const useShowFormDrawer = () => {
const {
clickedNodeId: clickNodeId,
setClickedNodeId,
getNode,
} = useGraphStore((state) => state);
const {
visible: formDrawerVisible,
hideModal: hideFormDrawer,
showModal: showFormDrawer,
} = useSetModalState();
const handleShow = useCallback(
(node: Node) => {
setClickedNodeId(node.id);
showFormDrawer();
},
[showFormDrawer, setClickedNodeId],
);
return {
formDrawerVisible,
hideFormDrawer,
showFormDrawer: handleShow,
clickedNode: getNode(clickNodeId),
};
};
export const useBuildDslData = () => {
const { data } = useFetchFlow();
const { nodes, edges } = useGraphStore((state) => state);
const buildDslData = useCallback(
(currentNodes?: Node[]) => {
const dslComponents = buildDslComponentsByGraph(
currentNodes ?? nodes,
edges,
data.dsl.components,
);
return {
...data.dsl,
graph: { nodes: currentNodes ?? nodes, edges },
components: dslComponents,
};
},
[data.dsl, edges, nodes],
);
return { buildDslData };
};
export const useSaveGraph = () => {
const { data } = useFetchFlow();
const { setFlow, loading } = useSetFlow();
const { id } = useParams();
const { buildDslData } = useBuildDslData();
const saveGraph = useCallback(
async (currentNodes?: Node[]) => {
return setFlow({
id,
title: data.title,
dsl: buildDslData(currentNodes),
});
},
[setFlow, id, data.title, buildDslData],
);
return { saveGraph, loading };
};
export const useHandleFormValuesChange = (id?: string) => {
const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
const handleValuesChange = useCallback(
@ -335,39 +282,6 @@ export const useHandleFormValuesChange = (id?: string) => {
return { handleValuesChange };
};
const useSetGraphInfo = () => {
const { setEdges, setNodes } = useGraphStore((state) => state);
const setGraphInfo = useCallback(
({ nodes = [], edges = [] }: IGraph) => {
if (nodes.length || edges.length) {
setNodes(nodes);
setEdges(edges);
}
},
[setEdges, setNodes],
);
return setGraphInfo;
};
export const useFetchDataOnMount = () => {
const { loading, data, refetch } = useFetchFlow();
const setGraphInfo = useSetGraphInfo();
useEffect(() => {
setGraphInfo(data?.dsl?.graph ?? ({} as IGraph));
}, [setGraphInfo, data]);
useEffect(() => {
refetch();
}, [refetch]);
return { loading, flowDetail: data };
};
export const useFlowIsFetching = () => {
return useIsFetching({ queryKey: ['flowDetail'] }) > 0;
};
export const useSetLlmSetting = (
form?: FormInstance,
formData?: Record<string, any>,
@ -401,7 +315,22 @@ export const useSetLlmSetting = (
};
export const useValidateConnection = () => {
const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);
const { edges, getOperatorTypeFromId, getParentIdById } = useGraphStore(
(state) => state,
);
const isSameNodeChild = useCallback(
(connection: Connection) => {
const sourceParentId = getParentIdById(connection.source);
const targetParentId = getParentIdById(connection.target);
if (sourceParentId || targetParentId) {
return sourceParentId === targetParentId;
}
return true;
},
[getParentIdById],
);
// restricted lines cannot be connected successfully.
const isValidConnection = useCallback(
(connection: Connection) => {
@ -418,10 +347,11 @@ export const useValidateConnection = () => {
!hasLine &&
RestrictedUpstreamMap[
getOperatorTypeFromId(connection.source) as Operator
]?.every((x) => x !== getOperatorTypeFromId(connection.target));
]?.every((x) => x !== getOperatorTypeFromId(connection.target)) &&
isSameNodeChild(connection);
return ret;
},
[edges, getOperatorTypeFromId],
[edges, getOperatorTypeFromId, isSameNodeChild],
);
return isValidConnection;
@ -464,52 +394,6 @@ export const useHandleNodeNameChange = ({
return { name, handleNameBlur, handleNameChange };
};
export const useGetBeginNodeDataQuery = () => {
const getNode = useGraphStore((state) => state.getNode);
const getBeginNodeDataQuery = useCallback(() => {
return get(getNode('begin'), 'data.form.query', []);
}, [getNode]);
return getBeginNodeDataQuery;
};
export const useGetBeginNodeDataQueryIsEmpty = () => {
const [isBeginNodeDataQueryEmpty, setIsBeginNodeDataQueryEmpty] =
useState(false);
const getBeginNodeDataQuery = useGetBeginNodeDataQuery();
const nodes = useGraphStore((state) => state.nodes);
useEffect(() => {
const query: BeginQuery[] = getBeginNodeDataQuery();
setIsBeginNodeDataQueryEmpty(query.length === 0);
}, [getBeginNodeDataQuery, nodes]);
return isBeginNodeDataQueryEmpty;
};
export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
const { saveGraph, loading } = useSaveGraph();
const { resetFlow } = useResetFlow();
const handleRun = useCallback(
async (nextNodes?: Node[]) => {
const saveRet = await saveGraph(nextNodes);
if (saveRet?.code === 0) {
// Call the reset api before opening the run drawer each time
const resetRet = await resetFlow();
// After resetting, all previous messages will be cleared.
if (resetRet?.code === 0) {
show();
}
}
},
[saveGraph, resetFlow, show],
);
return { handleRun, loading };
};
export const useReplaceIdWithName = () => {
const getNode = useGraphStore((state) => state.getNode);
@ -647,66 +531,6 @@ export const useWatchNodeFormDataChange = () => {
]);
};
// exclude nodes with branches
const ExcludedNodes = [
Operator.Categorize,
Operator.Relevant,
Operator.Begin,
Operator.Note,
];
export const useBuildComponentIdSelectOptions = (nodeId?: string) => {
const nodes = useGraphStore((state) => state.nodes);
const getBeginNodeDataQuery = useGetBeginNodeDataQuery();
const query: BeginQuery[] = getBeginNodeDataQuery();
const componentIdOptions = useMemo(() => {
return nodes
.filter(
(x) =>
x.id !== nodeId && !ExcludedNodes.some((y) => y === x.data.label),
)
.map((x) => ({ label: x.data.name, value: x.id }));
}, [nodes, nodeId]);
const groupedOptions = [
{
label: <span>Component Output</span>,
title: 'Component Output',
options: componentIdOptions,
},
{
label: <span>Begin Input</span>,
title: 'Begin Input',
options: query.map((x) => ({
label: x.name,
value: `begin@${x.key}`,
})),
},
];
return groupedOptions;
};
export const useGetComponentLabelByValue = (nodeId: string) => {
const options = useBuildComponentIdSelectOptions(nodeId);
const flattenOptions = useMemo(
() =>
options.reduce<DefaultOptionType[]>((pre, cur) => {
return [...pre, ...cur.options];
}, []),
[options],
);
const getLabel = useCallback(
(val?: string) => {
return flattenOptions.find((x) => x.value === val)?.label;
},
[flattenOptions],
);
return getLabel;
};
export const useDuplicateNode = () => {
const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
const getNodeName = useGetNodeName();
@ -769,107 +593,3 @@ export const useCopyPaste = () => {
};
}, [onPasteCapture]);
};
export const useWatchAgentChange = (chatDrawerVisible: boolean) => {
const [time, setTime] = useState<string>();
const nodes = useGraphStore((state) => state.nodes);
const edges = useGraphStore((state) => state.edges);
const { saveGraph } = useSaveGraph();
const { data: flowDetail } = useFetchFlow();
const setSaveTime = useCallback((updateTime: number) => {
setTime(dayjs(updateTime).format('YYYY-MM-DD HH:mm:ss'));
}, []);
useEffect(() => {
setSaveTime(flowDetail?.update_time);
}, [flowDetail, setSaveTime]);
const saveAgent = useCallback(async () => {
if (!chatDrawerVisible) {
const ret = await saveGraph();
setSaveTime(ret.data.update_time);
}
}, [chatDrawerVisible, saveGraph, setSaveTime]);
useDebounceEffect(
() => {
saveAgent();
},
[nodes, edges],
{
wait: 1000 * 20,
},
);
return time;
};
export const useHandleExportOrImportJsonFile = () => {
const { buildDslData } = useBuildDslData();
const {
visible: fileUploadVisible,
hideModal: hideFileUploadModal,
showModal: showFileUploadModal,
} = useSetModalState();
const setGraphInfo = useSetGraphInfo();
const { data } = useFetchFlow();
const { t } = useTranslation();
const onFileUploadOk = useCallback(
async (fileList: UploadFile[]) => {
if (fileList.length > 0) {
const file: File = fileList[0] as unknown as File;
if (file.type !== FileMimeType.Json) {
message.error(t('flow.jsonUploadTypeErrorMessage'));
return;
}
const graphStr = await file.text();
const errorMessage = t('flow.jsonUploadContentErrorMessage');
try {
const graph = JSON.parse(graphStr);
if (graphStr && !isEmpty(graph) && Array.isArray(graph?.nodes)) {
setGraphInfo(graph ?? ({} as IGraph));
hideFileUploadModal();
} else {
message.error(errorMessage);
}
} catch (error) {
message.error(errorMessage);
}
}
},
[hideFileUploadModal, setGraphInfo, t],
);
const handleExportJson = useCallback(() => {
downloadJsonFile(buildDslData().graph, `${data.title}.json`);
}, [buildDslData, data.title]);
return {
fileUploadVisible,
handleExportJson,
handleImportJson: showFileUploadModal,
hideFileUploadModal,
onFileUploadOk,
};
};
export const useShowSingleDebugDrawer = () => {
const { visible, showModal, hideModal } = useSetModalState();
const { saveGraph } = useSaveGraph();
const showSingleDebugDrawer = useCallback(async () => {
const saveRet = await saveGraph();
if (saveRet?.code === 0) {
showModal();
}
}, [saveGraph, showModal]);
return {
singleDebugDrawerVisible: visible,
hideSingleDebugDrawer: hideModal,
showSingleDebugDrawer,
};
};