diff --git a/web/src/pages/agent/canvas/node/retrieval-node.tsx b/web/src/pages/agent/canvas/node/retrieval-node.tsx index 6e9e214e5..05a2a974e 100644 --- a/web/src/pages/agent/canvas/node/retrieval-node.tsx +++ b/web/src/pages/agent/canvas/node/retrieval-node.tsx @@ -1,13 +1,16 @@ import { NodeCollapsible } from '@/components/collapse'; import { RAGFlowAvatar } from '@/components/ragflow-avatar'; import { useFetchKnowledgeList } from '@/hooks/use-knowledge-request'; -import { IRetrievalNode } from '@/interfaces/database/flow'; +import { useFetchAllMemoryList } from '@/hooks/use-memory-request'; +import { BaseNode } from '@/interfaces/database/flow'; import { NodeProps, Position } from '@xyflow/react'; import classNames from 'classnames'; import { get } from 'lodash'; import { memo } from 'react'; -import { NodeHandleId } from '../../constant'; +import { NodeHandleId, RetrievalFrom } from '../../constant'; +import { RetrievalFormSchemaType } from '../../form/retrieval-form/next'; import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query'; +import { LabelCard } from './card'; import { CommonHandle, LeftEndHandle } from './handle'; import styles from './index.less'; import NodeHeader from './node-header'; @@ -19,12 +22,17 @@ function InnerRetrievalNode({ data, isConnectable = true, selected, -}: NodeProps) { +}: NodeProps>) { const knowledgeBaseIds: string[] = get(data, 'form.kb_ids', []); + const memoryIds: string[] = get(data, 'form.memory_ids', []); const { list: knowledgeList } = useFetchKnowledgeList(true); const { getLabel } = useGetVariableLabelOrTypeByValue({ nodeId: id }); + const isMemory = data.form?.retrieval_from === RetrievalFrom.Memory; + + const memoryList = useFetchAllMemoryList(); + return ( @@ -45,8 +53,22 @@ function InnerRetrievalNode({ [styles.nodeHeader]: knowledgeBaseIds.length > 0, })} > - + {(id) => { + if (isMemory) { + const item = memoryList.data?.find((y) => id === y.id); + return ( + + + {item?.name} + + ); + } + const item = knowledgeList.find((y) => id === y.id); const label = getLabel(id); diff --git a/web/src/pages/agent/form/retrieval-form/next.tsx b/web/src/pages/agent/form/retrieval-form/next.tsx index b4f60e5f8..345efe43a 100644 --- a/web/src/pages/agent/form/retrieval-form/next.tsx +++ b/web/src/pages/agent/form/retrieval-form/next.tsx @@ -54,6 +54,7 @@ export const RetrievalPartialSchema = { toc_enhance: z.boolean(), ...MetadataFilterSchema, memory_ids: z.array(z.string()).optional(), + retrieval_from: z.string(), }; export const FormSchema = z.object({ @@ -61,6 +62,8 @@ export const FormSchema = z.object({ ...RetrievalPartialSchema, }); +export type RetrievalFormSchemaType = z.infer; + export function MemoryDatasetForm() { const { t } = useTranslation(); const form = useFormContext();