diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts index 8fd95d83d..aac61a503 100644 --- a/web/src/interfaces/database/agent.ts +++ b/web/src/interfaces/database/agent.ts @@ -3,6 +3,7 @@ export interface ICategorizeItem { description?: string; examples?: { value: string }[]; index: number; + to: string[]; } export type ICategorizeItemResult = Record< diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index a78768e01..e8cc2b241 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -66,7 +66,6 @@ export type RFState = { target?: string | null, isConnecting?: boolean, ) => void; - deletePreviousEdgeOfClassificationNode: (connection: Connection) => void; duplicateNode: (id: string, name: string) => void; duplicateIterationNode: (id: string, name: string) => void; deleteEdge: () => void; @@ -122,14 +121,10 @@ const useGraphStore = create()( setEdges(mapEdgeMouseEvent(edges, edgeId, false)); }, onConnect: (connection: Connection) => { - const { - deletePreviousEdgeOfClassificationNode, - updateFormDataOnConnect, - } = get(); + const { updateFormDataOnConnect } = get(); set({ edges: addEdge(connection, get().edges), }); - deletePreviousEdgeOfClassificationNode(connection); updateFormDataOnConnect(connection); }, onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => { @@ -216,7 +211,6 @@ const useGraphStore = create()( set({ edges: addEdge(connection, get().edges), }); - get().deletePreviousEdgeOfClassificationNode(connection); // TODO: This may not be reasonable. You need to choose between listening to changes in the form. get().updateFormDataOnConnect(connection); }, @@ -224,23 +218,11 @@ const useGraphStore = create()( return get().edges.find((x) => x.id === id); }, updateFormDataOnConnect: (connection: Connection) => { - const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } = - get(); + const { getOperatorTypeFromId, updateSwitchFormData } = get(); const { source, target, sourceHandle } = connection; const operatorType = getOperatorTypeFromId(source); if (source) { switch (operatorType) { - case Operator.Relevant: - updateNodeForm(source, { [sourceHandle as string]: target }); - break; - case Operator.Categorize: - if (sourceHandle) - updateNodeForm(source, target, [ - 'category_description', - sourceHandle, - 'to', - ]); - break; case Operator.Switch: { updateSwitchFormData(source, sourceHandle, target, true); break; @@ -250,31 +232,6 @@ const useGraphStore = create()( } } }, - deletePreviousEdgeOfClassificationNode: (connection: Connection) => { - // Delete the edge on the classification node or relevant node anchor when the anchor is connected to other nodes - const { edges, getOperatorTypeFromId, deleteEdgeById } = get(); - // the node containing the anchor - const anchoredNodes = [ - Operator.Categorize, - Operator.Relevant, - // Operator.Switch, - ]; - if ( - anchoredNodes.some( - (x) => x === getOperatorTypeFromId(connection.source), - ) - ) { - const previousEdge = edges.find( - (x) => - x.source === connection.source && - x.sourceHandle === connection.sourceHandle && - x.target !== connection.target, - ); - if (previousEdge) { - deleteEdgeById(previousEdge.id); - } - } - }, duplicateNode: (id: string, name: string) => { const { getNode, addNode, generateNodeName, duplicateIterationNode } = get(); diff --git a/web/src/pages/agent/utils.ts b/web/src/pages/agent/utils.ts index 06fcb5230..f1167b791 100644 --- a/web/src/pages/agent/utils.ts +++ b/web/src/pages/agent/utils.ts @@ -1,5 +1,6 @@ import { IAgentForm, + ICategorizeForm, ICategorizeItem, ICategorizeItemResult, } from '@/interfaces/database/agent'; @@ -158,6 +159,33 @@ function buildAgentTools(edges: Edge[], nodes: Node[], nodeId: string) { return params; } +function filterTargetsBySourceHandleId(edges: Edge[], handleId: string) { + return edges.filter((x) => x.sourceHandle === handleId).map((x) => x.target); +} + +function buildCategorizeTos(edges: Edge[], nodes: Node[], nodeId: string) { + const node = nodes.find((x) => x.id === nodeId); + const params = { ...(node?.data.form ?? {}) } as ICategorizeForm; + if (node && node.data.label === Operator.Categorize) { + const subEdges = edges.filter((x) => x.source === nodeId); + + const categoryDescription = params.category_description || {}; + + const nextCategoryDescription = Object.entries(categoryDescription).reduce< + ICategorizeForm['category_description'] + >((pre, [key, val]) => { + pre[key] = { + ...val, + to: filterTargetsBySourceHandleId(subEdges, key), + }; + return pre; + }, {}); + + params.category_description = nextCategoryDescription; + } + return params; +} + const buildOperatorParams = (operatorName: string) => pipe( removeUselessDataInTheOperator(operatorName), @@ -190,8 +218,20 @@ export const buildDslComponentsByGraph = ( .forEach((x) => { const id = x.id; const operatorName = x.data.label; + let params = x?.data.form ?? {}; + + switch (operatorName) { + case Operator.Agent: + params = buildAgentTools(edges, nodes, id); + break; + case Operator.Categorize: + params = buildCategorizeTos(edges, nodes, id); + break; + + default: + break; + } - const params = buildAgentTools(edges, nodes, id); components[id] = { obj: { ...(oldDslComponents[id]?.obj ?? {}),