mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-30 08:35:33 +08:00
### What problem does this PR solve? Feat: Add the iteration Node #4242 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -1,7 +1,3 @@
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
|
||||
import { IGraph } from '@/interfaces/database/flow';
|
||||
import { useIsFetching } from '@tanstack/react-query';
|
||||
import React, {
|
||||
ChangeEvent,
|
||||
useCallback,
|
||||
@ -12,23 +8,17 @@ import React, {
|
||||
import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
|
||||
// import { shallow } from 'zustand/shallow';
|
||||
import { variableEnabledFieldMap } from '@/constants/chat';
|
||||
import { FileMimeType } from '@/constants/common';
|
||||
import {
|
||||
ModelVariableType,
|
||||
settledModelVariableMap,
|
||||
} from '@/constants/knowledge';
|
||||
import { useFetchModelId } from '@/hooks/logic-hooks';
|
||||
import { Variable } from '@/interfaces/database/chat';
|
||||
import { downloadJsonFile } from '@/utils/file-util';
|
||||
import { useDebounceEffect } from 'ahooks';
|
||||
import { FormInstance, UploadFile, message } from 'antd';
|
||||
import { DefaultOptionType } from 'antd/es/select';
|
||||
import dayjs from 'dayjs';
|
||||
import { FormInstance, message } from 'antd';
|
||||
import { humanId } from 'human-id';
|
||||
import { get, isEmpty, lowerFirst, pick } from 'lodash';
|
||||
import trim from 'lodash/trim';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useParams } from 'umi';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import {
|
||||
NodeMap,
|
||||
@ -53,6 +43,7 @@ import {
|
||||
initialGoogleScholarValues,
|
||||
initialGoogleValues,
|
||||
initialInvokeValues,
|
||||
initialIterationValues,
|
||||
initialJin10Values,
|
||||
initialKeywordExtractValues,
|
||||
initialMessageValues,
|
||||
@ -69,18 +60,13 @@ import {
|
||||
initialWikipediaValues,
|
||||
initialYahooFinanceValues,
|
||||
} from './constant';
|
||||
import {
|
||||
BeginQuery,
|
||||
ICategorizeForm,
|
||||
IRelevantForm,
|
||||
ISwitchForm,
|
||||
} from './interface';
|
||||
import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface';
|
||||
import useGraphStore, { RFState } from './store';
|
||||
import {
|
||||
buildDslComponentsByGraph,
|
||||
generateNodeNamesWithIncreasingIndex,
|
||||
generateSwitchHandleText,
|
||||
getNodeDragHandle,
|
||||
getRelativePositionToIterationNode,
|
||||
replaceIdWithText,
|
||||
} from './utils';
|
||||
|
||||
@ -145,6 +131,8 @@ export const useInitializeOperatorParams = () => {
|
||||
[Operator.Invoke]: initialInvokeValues,
|
||||
[Operator.Template]: initialTemplateValues,
|
||||
[Operator.Email]: initialEmailValues,
|
||||
[Operator.Iteration]: initialIterationValues,
|
||||
[Operator.IterationStart]: initialIterationValues,
|
||||
};
|
||||
}, [llmId]);
|
||||
|
||||
@ -210,7 +198,7 @@ export const useHandleDrop = () => {
|
||||
x: event.clientX,
|
||||
y: event.clientY,
|
||||
});
|
||||
const newNode = {
|
||||
const newNode: Node<any> = {
|
||||
id: `${type}:${humanId()}`,
|
||||
type: NodeMap[type as Operator] || 'ragNode',
|
||||
position: position || {
|
||||
@ -227,7 +215,38 @@ export const useHandleDrop = () => {
|
||||
dragHandle: getNodeDragHandle(type),
|
||||
};
|
||||
|
||||
addNode(newNode);
|
||||
if (type === Operator.Iteration) {
|
||||
newNode.style = {
|
||||
width: 500,
|
||||
height: 250,
|
||||
};
|
||||
const iterationStartNode: Node<any> = {
|
||||
id: `${Operator.IterationStart}:${humanId()}`,
|
||||
type: 'iterationStartNode',
|
||||
position: { x: 50, y: 100 },
|
||||
// draggable: false,
|
||||
data: {
|
||||
label: Operator.IterationStart,
|
||||
name: Operator.IterationStart,
|
||||
form: {},
|
||||
},
|
||||
parentId: newNode.id,
|
||||
extent: 'parent',
|
||||
};
|
||||
addNode(newNode);
|
||||
addNode(iterationStartNode);
|
||||
} else {
|
||||
const subNodeOfIteration = getRelativePositionToIterationNode(
|
||||
nodes,
|
||||
position,
|
||||
);
|
||||
if (subNodeOfIteration) {
|
||||
newNode.parentId = subNodeOfIteration.parentId;
|
||||
newNode.position = subNodeOfIteration.position;
|
||||
newNode.extent = 'parent';
|
||||
}
|
||||
addNode(newNode);
|
||||
}
|
||||
},
|
||||
[reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode],
|
||||
);
|
||||
@ -235,78 +254,6 @@ export const useHandleDrop = () => {
|
||||
return { onDrop, onDragOver, setReactFlowInstance };
|
||||
};
|
||||
|
||||
export const useShowFormDrawer = () => {
|
||||
const {
|
||||
clickedNodeId: clickNodeId,
|
||||
setClickedNodeId,
|
||||
getNode,
|
||||
} = useGraphStore((state) => state);
|
||||
const {
|
||||
visible: formDrawerVisible,
|
||||
hideModal: hideFormDrawer,
|
||||
showModal: showFormDrawer,
|
||||
} = useSetModalState();
|
||||
|
||||
const handleShow = useCallback(
|
||||
(node: Node) => {
|
||||
setClickedNodeId(node.id);
|
||||
showFormDrawer();
|
||||
},
|
||||
[showFormDrawer, setClickedNodeId],
|
||||
);
|
||||
|
||||
return {
|
||||
formDrawerVisible,
|
||||
hideFormDrawer,
|
||||
showFormDrawer: handleShow,
|
||||
clickedNode: getNode(clickNodeId),
|
||||
};
|
||||
};
|
||||
|
||||
export const useBuildDslData = () => {
|
||||
const { data } = useFetchFlow();
|
||||
const { nodes, edges } = useGraphStore((state) => state);
|
||||
|
||||
const buildDslData = useCallback(
|
||||
(currentNodes?: Node[]) => {
|
||||
const dslComponents = buildDslComponentsByGraph(
|
||||
currentNodes ?? nodes,
|
||||
edges,
|
||||
data.dsl.components,
|
||||
);
|
||||
|
||||
return {
|
||||
...data.dsl,
|
||||
graph: { nodes: currentNodes ?? nodes, edges },
|
||||
components: dslComponents,
|
||||
};
|
||||
},
|
||||
[data.dsl, edges, nodes],
|
||||
);
|
||||
|
||||
return { buildDslData };
|
||||
};
|
||||
|
||||
export const useSaveGraph = () => {
|
||||
const { data } = useFetchFlow();
|
||||
const { setFlow, loading } = useSetFlow();
|
||||
const { id } = useParams();
|
||||
const { buildDslData } = useBuildDslData();
|
||||
|
||||
const saveGraph = useCallback(
|
||||
async (currentNodes?: Node[]) => {
|
||||
return setFlow({
|
||||
id,
|
||||
title: data.title,
|
||||
dsl: buildDslData(currentNodes),
|
||||
});
|
||||
},
|
||||
[setFlow, id, data.title, buildDslData],
|
||||
);
|
||||
|
||||
return { saveGraph, loading };
|
||||
};
|
||||
|
||||
export const useHandleFormValuesChange = (id?: string) => {
|
||||
const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
|
||||
const handleValuesChange = useCallback(
|
||||
@ -335,39 +282,6 @@ export const useHandleFormValuesChange = (id?: string) => {
|
||||
return { handleValuesChange };
|
||||
};
|
||||
|
||||
const useSetGraphInfo = () => {
|
||||
const { setEdges, setNodes } = useGraphStore((state) => state);
|
||||
const setGraphInfo = useCallback(
|
||||
({ nodes = [], edges = [] }: IGraph) => {
|
||||
if (nodes.length || edges.length) {
|
||||
setNodes(nodes);
|
||||
setEdges(edges);
|
||||
}
|
||||
},
|
||||
[setEdges, setNodes],
|
||||
);
|
||||
return setGraphInfo;
|
||||
};
|
||||
|
||||
export const useFetchDataOnMount = () => {
|
||||
const { loading, data, refetch } = useFetchFlow();
|
||||
const setGraphInfo = useSetGraphInfo();
|
||||
|
||||
useEffect(() => {
|
||||
setGraphInfo(data?.dsl?.graph ?? ({} as IGraph));
|
||||
}, [setGraphInfo, data]);
|
||||
|
||||
useEffect(() => {
|
||||
refetch();
|
||||
}, [refetch]);
|
||||
|
||||
return { loading, flowDetail: data };
|
||||
};
|
||||
|
||||
export const useFlowIsFetching = () => {
|
||||
return useIsFetching({ queryKey: ['flowDetail'] }) > 0;
|
||||
};
|
||||
|
||||
export const useSetLlmSetting = (
|
||||
form?: FormInstance,
|
||||
formData?: Record<string, any>,
|
||||
@ -401,7 +315,22 @@ export const useSetLlmSetting = (
|
||||
};
|
||||
|
||||
export const useValidateConnection = () => {
|
||||
const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);
|
||||
const { edges, getOperatorTypeFromId, getParentIdById } = useGraphStore(
|
||||
(state) => state,
|
||||
);
|
||||
|
||||
const isSameNodeChild = useCallback(
|
||||
(connection: Connection) => {
|
||||
const sourceParentId = getParentIdById(connection.source);
|
||||
const targetParentId = getParentIdById(connection.target);
|
||||
if (sourceParentId || targetParentId) {
|
||||
return sourceParentId === targetParentId;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[getParentIdById],
|
||||
);
|
||||
|
||||
// restricted lines cannot be connected successfully.
|
||||
const isValidConnection = useCallback(
|
||||
(connection: Connection) => {
|
||||
@ -418,10 +347,11 @@ export const useValidateConnection = () => {
|
||||
!hasLine &&
|
||||
RestrictedUpstreamMap[
|
||||
getOperatorTypeFromId(connection.source) as Operator
|
||||
]?.every((x) => x !== getOperatorTypeFromId(connection.target));
|
||||
]?.every((x) => x !== getOperatorTypeFromId(connection.target)) &&
|
||||
isSameNodeChild(connection);
|
||||
return ret;
|
||||
},
|
||||
[edges, getOperatorTypeFromId],
|
||||
[edges, getOperatorTypeFromId, isSameNodeChild],
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
@ -464,52 +394,6 @@ export const useHandleNodeNameChange = ({
|
||||
return { name, handleNameBlur, handleNameChange };
|
||||
};
|
||||
|
||||
export const useGetBeginNodeDataQuery = () => {
|
||||
const getNode = useGraphStore((state) => state.getNode);
|
||||
|
||||
const getBeginNodeDataQuery = useCallback(() => {
|
||||
return get(getNode('begin'), 'data.form.query', []);
|
||||
}, [getNode]);
|
||||
|
||||
return getBeginNodeDataQuery;
|
||||
};
|
||||
|
||||
export const useGetBeginNodeDataQueryIsEmpty = () => {
|
||||
const [isBeginNodeDataQueryEmpty, setIsBeginNodeDataQueryEmpty] =
|
||||
useState(false);
|
||||
const getBeginNodeDataQuery = useGetBeginNodeDataQuery();
|
||||
const nodes = useGraphStore((state) => state.nodes);
|
||||
|
||||
useEffect(() => {
|
||||
const query: BeginQuery[] = getBeginNodeDataQuery();
|
||||
setIsBeginNodeDataQueryEmpty(query.length === 0);
|
||||
}, [getBeginNodeDataQuery, nodes]);
|
||||
|
||||
return isBeginNodeDataQueryEmpty;
|
||||
};
|
||||
|
||||
export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
|
||||
const { saveGraph, loading } = useSaveGraph();
|
||||
const { resetFlow } = useResetFlow();
|
||||
|
||||
const handleRun = useCallback(
|
||||
async (nextNodes?: Node[]) => {
|
||||
const saveRet = await saveGraph(nextNodes);
|
||||
if (saveRet?.code === 0) {
|
||||
// Call the reset api before opening the run drawer each time
|
||||
const resetRet = await resetFlow();
|
||||
// After resetting, all previous messages will be cleared.
|
||||
if (resetRet?.code === 0) {
|
||||
show();
|
||||
}
|
||||
}
|
||||
},
|
||||
[saveGraph, resetFlow, show],
|
||||
);
|
||||
|
||||
return { handleRun, loading };
|
||||
};
|
||||
|
||||
export const useReplaceIdWithName = () => {
|
||||
const getNode = useGraphStore((state) => state.getNode);
|
||||
|
||||
@ -647,66 +531,6 @@ export const useWatchNodeFormDataChange = () => {
|
||||
]);
|
||||
};
|
||||
|
||||
// exclude nodes with branches
|
||||
const ExcludedNodes = [
|
||||
Operator.Categorize,
|
||||
Operator.Relevant,
|
||||
Operator.Begin,
|
||||
Operator.Note,
|
||||
];
|
||||
|
||||
export const useBuildComponentIdSelectOptions = (nodeId?: string) => {
|
||||
const nodes = useGraphStore((state) => state.nodes);
|
||||
const getBeginNodeDataQuery = useGetBeginNodeDataQuery();
|
||||
const query: BeginQuery[] = getBeginNodeDataQuery();
|
||||
|
||||
const componentIdOptions = useMemo(() => {
|
||||
return nodes
|
||||
.filter(
|
||||
(x) =>
|
||||
x.id !== nodeId && !ExcludedNodes.some((y) => y === x.data.label),
|
||||
)
|
||||
.map((x) => ({ label: x.data.name, value: x.id }));
|
||||
}, [nodes, nodeId]);
|
||||
|
||||
const groupedOptions = [
|
||||
{
|
||||
label: <span>Component Output</span>,
|
||||
title: 'Component Output',
|
||||
options: componentIdOptions,
|
||||
},
|
||||
{
|
||||
label: <span>Begin Input</span>,
|
||||
title: 'Begin Input',
|
||||
options: query.map((x) => ({
|
||||
label: x.name,
|
||||
value: `begin@${x.key}`,
|
||||
})),
|
||||
},
|
||||
];
|
||||
|
||||
return groupedOptions;
|
||||
};
|
||||
|
||||
export const useGetComponentLabelByValue = (nodeId: string) => {
|
||||
const options = useBuildComponentIdSelectOptions(nodeId);
|
||||
const flattenOptions = useMemo(
|
||||
() =>
|
||||
options.reduce<DefaultOptionType[]>((pre, cur) => {
|
||||
return [...pre, ...cur.options];
|
||||
}, []),
|
||||
[options],
|
||||
);
|
||||
|
||||
const getLabel = useCallback(
|
||||
(val?: string) => {
|
||||
return flattenOptions.find((x) => x.value === val)?.label;
|
||||
},
|
||||
[flattenOptions],
|
||||
);
|
||||
return getLabel;
|
||||
};
|
||||
|
||||
export const useDuplicateNode = () => {
|
||||
const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
|
||||
const getNodeName = useGetNodeName();
|
||||
@ -769,107 +593,3 @@ export const useCopyPaste = () => {
|
||||
};
|
||||
}, [onPasteCapture]);
|
||||
};
|
||||
|
||||
export const useWatchAgentChange = (chatDrawerVisible: boolean) => {
|
||||
const [time, setTime] = useState<string>();
|
||||
const nodes = useGraphStore((state) => state.nodes);
|
||||
const edges = useGraphStore((state) => state.edges);
|
||||
const { saveGraph } = useSaveGraph();
|
||||
const { data: flowDetail } = useFetchFlow();
|
||||
|
||||
const setSaveTime = useCallback((updateTime: number) => {
|
||||
setTime(dayjs(updateTime).format('YYYY-MM-DD HH:mm:ss'));
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
setSaveTime(flowDetail?.update_time);
|
||||
}, [flowDetail, setSaveTime]);
|
||||
|
||||
const saveAgent = useCallback(async () => {
|
||||
if (!chatDrawerVisible) {
|
||||
const ret = await saveGraph();
|
||||
setSaveTime(ret.data.update_time);
|
||||
}
|
||||
}, [chatDrawerVisible, saveGraph, setSaveTime]);
|
||||
|
||||
useDebounceEffect(
|
||||
() => {
|
||||
saveAgent();
|
||||
},
|
||||
[nodes, edges],
|
||||
{
|
||||
wait: 1000 * 20,
|
||||
},
|
||||
);
|
||||
|
||||
return time;
|
||||
};
|
||||
|
||||
export const useHandleExportOrImportJsonFile = () => {
|
||||
const { buildDslData } = useBuildDslData();
|
||||
const {
|
||||
visible: fileUploadVisible,
|
||||
hideModal: hideFileUploadModal,
|
||||
showModal: showFileUploadModal,
|
||||
} = useSetModalState();
|
||||
const setGraphInfo = useSetGraphInfo();
|
||||
const { data } = useFetchFlow();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onFileUploadOk = useCallback(
|
||||
async (fileList: UploadFile[]) => {
|
||||
if (fileList.length > 0) {
|
||||
const file: File = fileList[0] as unknown as File;
|
||||
if (file.type !== FileMimeType.Json) {
|
||||
message.error(t('flow.jsonUploadTypeErrorMessage'));
|
||||
return;
|
||||
}
|
||||
|
||||
const graphStr = await file.text();
|
||||
const errorMessage = t('flow.jsonUploadContentErrorMessage');
|
||||
try {
|
||||
const graph = JSON.parse(graphStr);
|
||||
if (graphStr && !isEmpty(graph) && Array.isArray(graph?.nodes)) {
|
||||
setGraphInfo(graph ?? ({} as IGraph));
|
||||
hideFileUploadModal();
|
||||
} else {
|
||||
message.error(errorMessage);
|
||||
}
|
||||
} catch (error) {
|
||||
message.error(errorMessage);
|
||||
}
|
||||
}
|
||||
},
|
||||
[hideFileUploadModal, setGraphInfo, t],
|
||||
);
|
||||
|
||||
const handleExportJson = useCallback(() => {
|
||||
downloadJsonFile(buildDslData().graph, `${data.title}.json`);
|
||||
}, [buildDslData, data.title]);
|
||||
|
||||
return {
|
||||
fileUploadVisible,
|
||||
handleExportJson,
|
||||
handleImportJson: showFileUploadModal,
|
||||
hideFileUploadModal,
|
||||
onFileUploadOk,
|
||||
};
|
||||
};
|
||||
|
||||
export const useShowSingleDebugDrawer = () => {
|
||||
const { visible, showModal, hideModal } = useSetModalState();
|
||||
const { saveGraph } = useSaveGraph();
|
||||
|
||||
const showSingleDebugDrawer = useCallback(async () => {
|
||||
const saveRet = await saveGraph();
|
||||
if (saveRet?.code === 0) {
|
||||
showModal();
|
||||
}
|
||||
}, [saveGraph, showModal]);
|
||||
|
||||
return {
|
||||
singleDebugDrawerVisible: visible,
|
||||
hideSingleDebugDrawer: hideModal,
|
||||
showSingleDebugDrawer,
|
||||
};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user