mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
### What problem does this PR solve? Feat: Handling abnormal anchor points of agent operators #3221 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -1,7 +1,9 @@
|
|||||||
|
import LLMLabel from '@/components/llm-select/llm-label';
|
||||||
import { IAgentNode } from '@/interfaces/database/flow';
|
import { IAgentNode } from '@/interfaces/database/flow';
|
||||||
import { Handle, NodeProps, Position } from '@xyflow/react';
|
import { Handle, NodeProps, Position } from '@xyflow/react';
|
||||||
|
import { get } from 'lodash';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { NodeHandleId } from '../../constant';
|
import { AgentExceptionMethod, NodeHandleId } from '../../constant';
|
||||||
import useGraphStore from '../../store';
|
import useGraphStore from '../../store';
|
||||||
import { isBottomSubAgent } from '../../utils';
|
import { isBottomSubAgent } from '../../utils';
|
||||||
import { CommonHandle } from './handle';
|
import { CommonHandle } from './handle';
|
||||||
@ -23,6 +25,14 @@ function InnerAgentNode({
|
|||||||
return !isBottomSubAgent(edges, id);
|
return !isBottomSubAgent(edges, id);
|
||||||
}, [edges, id]);
|
}, [edges, id]);
|
||||||
|
|
||||||
|
const exceptionMethod = useMemo(() => {
|
||||||
|
return get(data, 'form.exception_method');
|
||||||
|
}, [data]);
|
||||||
|
|
||||||
|
const isGotoMethod = useMemo(() => {
|
||||||
|
return exceptionMethod === AgentExceptionMethod.Goto;
|
||||||
|
}, [exceptionMethod]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ToolBar selected={selected} id={id} label={data.label}>
|
<ToolBar selected={selected} id={id} label={data.label}>
|
||||||
<NodeWrapper selected={selected}>
|
<NodeWrapper selected={selected}>
|
||||||
@ -48,6 +58,7 @@ function InnerAgentNode({
|
|||||||
></CommonHandle>
|
></CommonHandle>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<Handle
|
<Handle
|
||||||
type="target"
|
type="target"
|
||||||
position={Position.Top}
|
position={Position.Top}
|
||||||
@ -69,6 +80,32 @@ function InnerAgentNode({
|
|||||||
style={{ left: 20 }}
|
style={{ left: 20 }}
|
||||||
></Handle>
|
></Handle>
|
||||||
<NodeHeader id={id} name={data.name} label={data.label}></NodeHeader>
|
<NodeHeader id={id} name={data.name} label={data.label}></NodeHeader>
|
||||||
|
<section className="flex flex-col gap-2">
|
||||||
|
<div className={'bg-background-card rounded-sm p-1'}>
|
||||||
|
<LLMLabel value={get(data, 'form.llm_id')}></LLMLabel>
|
||||||
|
</div>
|
||||||
|
{(isGotoMethod ||
|
||||||
|
exceptionMethod === AgentExceptionMethod.Comment) && (
|
||||||
|
<div className="bg-background-card rounded-sm p-1 flex justify-between gap-2">
|
||||||
|
<span className="text-text-sub-title">Abnormal</span>
|
||||||
|
<span className="truncate flex-1">
|
||||||
|
{isGotoMethod ? 'Exception branch' : 'Output default value'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</section>
|
||||||
|
{isGotoMethod && (
|
||||||
|
<CommonHandle
|
||||||
|
type="source"
|
||||||
|
position={Position.Right}
|
||||||
|
isConnectable={isConnectable}
|
||||||
|
className="!bg-text-delete-red"
|
||||||
|
style={{ ...RightHandleStyle, top: 94 }}
|
||||||
|
nodeId={id}
|
||||||
|
id={NodeHandleId.AgentException}
|
||||||
|
isConnectableEnd={false}
|
||||||
|
></CommonHandle>
|
||||||
|
)}
|
||||||
</NodeWrapper>
|
</NodeWrapper>
|
||||||
</ToolBar>
|
</ToolBar>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -644,19 +644,20 @@ export const initialAgentValues = {
|
|||||||
max_rounds: 5,
|
max_rounds: 5,
|
||||||
exception_method: null,
|
exception_method: null,
|
||||||
exception_comment: '',
|
exception_comment: '',
|
||||||
exception_goto: '',
|
exception_goto: [],
|
||||||
|
exception_default_value: '',
|
||||||
tools: [],
|
tools: [],
|
||||||
mcp: [],
|
mcp: [],
|
||||||
outputs: {
|
outputs: {
|
||||||
structured_output: {
|
// structured_output: {
|
||||||
// topic: {
|
// topic: {
|
||||||
// type: 'string',
|
// type: 'string',
|
||||||
// description:
|
// description:
|
||||||
// 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.',
|
// 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.',
|
||||||
// enum: ['general', 'news'],
|
// enum: ['general', 'news'],
|
||||||
// default: 'general',
|
// default: 'general',
|
||||||
// },
|
// },
|
||||||
},
|
// },
|
||||||
content: {
|
content: {
|
||||||
type: 'string',
|
type: 'string',
|
||||||
value: '',
|
value: '',
|
||||||
@ -932,6 +933,7 @@ export enum NodeHandleId {
|
|||||||
Tool = 'tool',
|
Tool = 'tool',
|
||||||
AgentTop = 'agentTop',
|
AgentTop = 'agentTop',
|
||||||
AgentBottom = 'agentBottom',
|
AgentBottom = 'agentBottom',
|
||||||
|
AgentException = 'agentException',
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum VariableType {
|
export enum VariableType {
|
||||||
|
|||||||
@ -30,40 +30,6 @@ export const useBuildFormSelectOptions = (
|
|||||||
return buildCategorizeToOptions;
|
return buildCategorizeToOptions;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* dumped
|
|
||||||
* @param nodeId
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
export const useHandleFormSelectChange = (nodeId?: string) => {
|
|
||||||
const { addEdge, deleteEdgeBySourceAndSourceHandle } = useGraphStore(
|
|
||||||
(state) => state,
|
|
||||||
);
|
|
||||||
const handleSelectChange = useCallback(
|
|
||||||
(name?: string) => (value?: string) => {
|
|
||||||
if (nodeId && name) {
|
|
||||||
if (value) {
|
|
||||||
addEdge({
|
|
||||||
source: nodeId,
|
|
||||||
target: value,
|
|
||||||
sourceHandle: name,
|
|
||||||
targetHandle: null,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// clear selected value
|
|
||||||
deleteEdgeBySourceAndSourceHandle({
|
|
||||||
source: nodeId,
|
|
||||||
sourceHandle: name,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[addEdge, nodeId, deleteEdgeBySourceAndSourceHandle],
|
|
||||||
);
|
|
||||||
|
|
||||||
return { handleSelectChange };
|
|
||||||
};
|
|
||||||
|
|
||||||
export const useBuildSortOptions = () => {
|
export const useBuildSortOptions = () => {
|
||||||
const { t } = useTranslate('flow');
|
const { t } = useTranslate('flow');
|
||||||
|
|
||||||
|
|||||||
@ -19,19 +19,22 @@ import { LlmModelType } from '@/constants/knowledge';
|
|||||||
import { useFindLlmByUuid } from '@/hooks/use-llm-request';
|
import { useFindLlmByUuid } from '@/hooks/use-llm-request';
|
||||||
import { buildOptions } from '@/utils/form';
|
import { buildOptions } from '@/utils/form';
|
||||||
import { zodResolver } from '@hookform/resolvers/zod';
|
import { zodResolver } from '@hookform/resolvers/zod';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useEffect, useMemo } from 'react';
|
||||||
import { useForm, useWatch } from 'react-hook-form';
|
import { useForm, useWatch } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import {
|
import {
|
||||||
AgentExceptionMethod,
|
AgentExceptionMethod,
|
||||||
|
NodeHandleId,
|
||||||
VariableType,
|
VariableType,
|
||||||
initialAgentValues,
|
initialAgentValues,
|
||||||
} from '../../constant';
|
} from '../../constant';
|
||||||
import { INextOperatorForm } from '../../interface';
|
import { INextOperatorForm } from '../../interface';
|
||||||
import useGraphStore from '../../store';
|
import useGraphStore from '../../store';
|
||||||
import { isBottomSubAgent } from '../../utils';
|
import { isBottomSubAgent } from '../../utils';
|
||||||
|
import { buildOutputList } from '../../utils/build-output-list';
|
||||||
import { DescriptionField } from '../components/description-field';
|
import { DescriptionField } from '../components/description-field';
|
||||||
|
import { FormWrapper } from '../components/form-wrapper';
|
||||||
import { Output } from '../components/output';
|
import { Output } from '../components/output';
|
||||||
import { PromptEditor } from '../components/prompt-editor';
|
import { PromptEditor } from '../components/prompt-editor';
|
||||||
import { QueryVariable } from '../components/query-variable';
|
import { QueryVariable } from '../components/query-variable';
|
||||||
@ -69,13 +72,18 @@ const FormSchema = z.object({
|
|||||||
max_rounds: z.coerce.number().optional(),
|
max_rounds: z.coerce.number().optional(),
|
||||||
exception_method: z.string().nullable(),
|
exception_method: z.string().nullable(),
|
||||||
exception_comment: z.string().optional(),
|
exception_comment: z.string().optional(),
|
||||||
exception_goto: z.string().optional(),
|
exception_goto: z.array(z.string()).optional(),
|
||||||
|
exception_default_value: z.string().optional(),
|
||||||
...LargeModelFilterFormSchema,
|
...LargeModelFilterFormSchema,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const outputList = buildOutputList(initialAgentValues.outputs);
|
||||||
|
|
||||||
function AgentForm({ node }: INextOperatorForm) {
|
function AgentForm({ node }: INextOperatorForm) {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { edges } = useGraphStore((state) => state);
|
const { edges, deleteEdgesBySourceAndSourceHandle } = useGraphStore(
|
||||||
|
(state) => state,
|
||||||
|
);
|
||||||
|
|
||||||
const defaultValues = useValues(node);
|
const defaultValues = useValues(node);
|
||||||
|
|
||||||
@ -83,12 +91,6 @@ function AgentForm({ node }: INextOperatorForm) {
|
|||||||
return isBottomSubAgent(edges, node?.id);
|
return isBottomSubAgent(edges, node?.id);
|
||||||
}, [edges, node?.id]);
|
}, [edges, node?.id]);
|
||||||
|
|
||||||
const outputList = useMemo(() => {
|
|
||||||
return [
|
|
||||||
{ title: 'content', type: initialAgentValues.outputs.content.type },
|
|
||||||
];
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const form = useForm<z.infer<typeof FormSchema>>({
|
const form = useForm<z.infer<typeof FormSchema>>({
|
||||||
defaultValues: defaultValues,
|
defaultValues: defaultValues,
|
||||||
resolver: zodResolver(FormSchema),
|
resolver: zodResolver(FormSchema),
|
||||||
@ -98,16 +100,27 @@ function AgentForm({ node }: INextOperatorForm) {
|
|||||||
|
|
||||||
const findLlmByUuid = useFindLlmByUuid();
|
const findLlmByUuid = useFindLlmByUuid();
|
||||||
|
|
||||||
|
const exceptionMethod = useWatch({
|
||||||
|
control: form.control,
|
||||||
|
name: 'exception_method',
|
||||||
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (exceptionMethod !== AgentExceptionMethod.Goto) {
|
||||||
|
if (node?.id) {
|
||||||
|
deleteEdgesBySourceAndSourceHandle(
|
||||||
|
node?.id,
|
||||||
|
NodeHandleId.AgentException,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [deleteEdgesBySourceAndSourceHandle, exceptionMethod, node?.id]);
|
||||||
|
|
||||||
useWatchFormChange(node?.id, form);
|
useWatchFormChange(node?.id, form);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Form {...form}>
|
<Form {...form}>
|
||||||
<form
|
<FormWrapper>
|
||||||
className="space-y-6 p-4"
|
|
||||||
onSubmit={(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<FormContainer>
|
<FormContainer>
|
||||||
{isSubAgent && <DescriptionField></DescriptionField>}
|
{isSubAgent && <DescriptionField></DescriptionField>}
|
||||||
<LargeModelFormField></LargeModelFormField>
|
<LargeModelFormField></LargeModelFormField>
|
||||||
@ -219,6 +232,18 @@ function AgentForm({ node }: INextOperatorForm) {
|
|||||||
</FormItem>
|
</FormItem>
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name={`exception_default_value`}
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem className="flex-1">
|
||||||
|
<FormLabel>Exception default value</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input {...field} />
|
||||||
|
</FormControl>
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
<FormField
|
<FormField
|
||||||
control={form.control}
|
control={form.control}
|
||||||
name={`exception_comment`}
|
name={`exception_comment`}
|
||||||
@ -231,15 +256,10 @@ function AgentForm({ node }: INextOperatorForm) {
|
|||||||
</FormItem>
|
</FormItem>
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
<QueryVariable
|
|
||||||
name="exception_goto"
|
|
||||||
label="Exception goto"
|
|
||||||
type={VariableType.File}
|
|
||||||
></QueryVariable>
|
|
||||||
</FormContainer>
|
</FormContainer>
|
||||||
</Collapse>
|
</Collapse>
|
||||||
<Output list={outputList}></Output>
|
<Output list={outputList}></Output>
|
||||||
</form>
|
</FormWrapper>
|
||||||
</Form>
|
</Form>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -177,7 +177,7 @@ const DynamicCategorize = ({ nodeId }: IProps) => {
|
|||||||
const FormSchema = useCreateCategorizeFormSchema();
|
const FormSchema = useCreateCategorizeFormSchema();
|
||||||
|
|
||||||
const deleteCategorizeCaseEdges = useGraphStore(
|
const deleteCategorizeCaseEdges = useGraphStore(
|
||||||
(state) => state.deleteCategorizeCaseEdges,
|
(state) => state.deleteEdgesBySourceAndSourceHandle,
|
||||||
);
|
);
|
||||||
const form = useFormContext<z.infer<typeof FormSchema>>();
|
const form = useFormContext<z.infer<typeof FormSchema>>();
|
||||||
const { t } = useTranslate('flow');
|
const { t } = useTranslate('flow');
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import {
|
|||||||
FormLabel,
|
FormLabel,
|
||||||
FormMessage,
|
FormMessage,
|
||||||
} from '@/components/ui/form';
|
} from '@/components/ui/form';
|
||||||
import { Input } from '@/components/ui/input';
|
|
||||||
import { RAGFlowSelect } from '@/components/ui/select';
|
import { RAGFlowSelect } from '@/components/ui/select';
|
||||||
import { buildOptions } from '@/utils/form';
|
import { buildOptions } from '@/utils/form';
|
||||||
import { zodResolver } from '@hookform/resolvers/zod';
|
import { zodResolver } from '@hookform/resolvers/zod';
|
||||||
@ -26,6 +25,7 @@ import { buildOutputList } from '../../utils/build-output-list';
|
|||||||
import { ApiKeyField } from '../components/api-key-field';
|
import { ApiKeyField } from '../components/api-key-field';
|
||||||
import { FormWrapper } from '../components/form-wrapper';
|
import { FormWrapper } from '../components/form-wrapper';
|
||||||
import { Output } from '../components/output';
|
import { Output } from '../components/output';
|
||||||
|
import { PromptEditor } from '../components/prompt-editor';
|
||||||
import { TavilyFormSchema } from '../tavily-form';
|
import { TavilyFormSchema } from '../tavily-form';
|
||||||
|
|
||||||
const outputList = buildOutputList(initialTavilyExtractValues.outputs);
|
const outputList = buildOutputList(initialTavilyExtractValues.outputs);
|
||||||
@ -64,7 +64,11 @@ function TavilyExtractForm({ node }: INextOperatorForm) {
|
|||||||
<FormItem>
|
<FormItem>
|
||||||
<FormLabel>URL</FormLabel>
|
<FormLabel>URL</FormLabel>
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<Input {...field} />
|
<PromptEditor
|
||||||
|
{...field}
|
||||||
|
multiLine={false}
|
||||||
|
showToolbar={false}
|
||||||
|
></PromptEditor>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormMessage />
|
<FormMessage />
|
||||||
</FormItem>
|
</FormItem>
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import { RAGFlowSelect } from '@/components/ui/select';
|
|||||||
import { Switch } from '@/components/ui/switch';
|
import { Switch } from '@/components/ui/switch';
|
||||||
import { buildOptions } from '@/utils/form';
|
import { buildOptions } from '@/utils/form';
|
||||||
import { zodResolver } from '@hookform/resolvers/zod';
|
import { zodResolver } from '@hookform/resolvers/zod';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import {
|
import {
|
||||||
@ -21,9 +21,10 @@ import {
|
|||||||
initialTavilyValues,
|
initialTavilyValues,
|
||||||
} from '../../constant';
|
} from '../../constant';
|
||||||
import { INextOperatorForm } from '../../interface';
|
import { INextOperatorForm } from '../../interface';
|
||||||
|
import { buildOutputList } from '../../utils/build-output-list';
|
||||||
import { ApiKeyField } from '../components/api-key-field';
|
import { ApiKeyField } from '../components/api-key-field';
|
||||||
import { FormWrapper } from '../components/form-wrapper';
|
import { FormWrapper } from '../components/form-wrapper';
|
||||||
import { Output, OutputType } from '../components/output';
|
import { Output } from '../components/output';
|
||||||
import { QueryVariable } from '../components/query-variable';
|
import { QueryVariable } from '../components/query-variable';
|
||||||
import { DynamicDomain } from './dynamic-domain';
|
import { DynamicDomain } from './dynamic-domain';
|
||||||
import { useValues } from './use-values';
|
import { useValues } from './use-values';
|
||||||
@ -33,6 +34,8 @@ export const TavilyFormSchema = {
|
|||||||
api_key: z.string(),
|
api_key: z.string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const outputList = buildOutputList(initialTavilyValues.outputs);
|
||||||
|
|
||||||
function TavilyForm({ node }: INextOperatorForm) {
|
function TavilyForm({ node }: INextOperatorForm) {
|
||||||
const values = useValues(node);
|
const values = useValues(node);
|
||||||
|
|
||||||
@ -56,16 +59,6 @@ function TavilyForm({ node }: INextOperatorForm) {
|
|||||||
resolver: zodResolver(FormSchema),
|
resolver: zodResolver(FormSchema),
|
||||||
});
|
});
|
||||||
|
|
||||||
const outputList = useMemo(() => {
|
|
||||||
return Object.entries(initialTavilyValues.outputs).reduce<OutputType[]>(
|
|
||||||
(pre, [key, val]) => {
|
|
||||||
pre.push({ title: key, type: val.type });
|
|
||||||
return pre;
|
|
||||||
},
|
|
||||||
[],
|
|
||||||
);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useWatchFormChange(node?.id, form);
|
useWatchFormChange(node?.id, form);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -74,7 +74,6 @@ export type RFState = {
|
|||||||
deleteAgentDownstreamNodesById: (id: string) => void;
|
deleteAgentDownstreamNodesById: (id: string) => void;
|
||||||
deleteAgentToolNodeById: (id: string) => void;
|
deleteAgentToolNodeById: (id: string) => void;
|
||||||
deleteIterationNodeById: (id: string) => void;
|
deleteIterationNodeById: (id: string) => void;
|
||||||
deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
|
|
||||||
findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
|
findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
|
||||||
updateMutableNodeFormItem: (id: string, field: string, value: any) => void;
|
updateMutableNodeFormItem: (id: string, field: string, value: any) => void;
|
||||||
getOperatorTypeFromId: (id?: string | null) => string | undefined;
|
getOperatorTypeFromId: (id?: string | null) => string | undefined;
|
||||||
@ -84,7 +83,10 @@ export type RFState = {
|
|||||||
setClickedNodeId: (id?: string) => void;
|
setClickedNodeId: (id?: string) => void;
|
||||||
setClickedToolId: (id?: string) => void;
|
setClickedToolId: (id?: string) => void;
|
||||||
findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined;
|
findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined;
|
||||||
deleteCategorizeCaseEdges: (source: string, sourceHandle: string) => void; // Deleting a condition of a classification operator will delete the related edge
|
deleteEdgesBySourceAndSourceHandle: (
|
||||||
|
source: string,
|
||||||
|
sourceHandle: string,
|
||||||
|
) => void; // Deleting a condition of a classification operator will delete the related edge
|
||||||
findAgentToolNodeById: (id: string | null) => string | undefined;
|
findAgentToolNodeById: (id: string | null) => string | undefined;
|
||||||
selectNodeIds: (nodeIds: string[]) => void;
|
selectNodeIds: (nodeIds: string[]) => void;
|
||||||
};
|
};
|
||||||
@ -330,19 +332,6 @@ const useGraphStore = create<RFState>()(
|
|||||||
edges: edges.filter((edge) => edge.id !== id),
|
edges: edges.filter((edge) => edge.id !== id),
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
deleteEdgeBySourceAndSourceHandle: ({
|
|
||||||
source,
|
|
||||||
sourceHandle,
|
|
||||||
}: Partial<Connection>) => {
|
|
||||||
const { edges } = get();
|
|
||||||
const nextEdges = edges.filter(
|
|
||||||
(edge) =>
|
|
||||||
edge.source !== source || edge.sourceHandle !== sourceHandle,
|
|
||||||
);
|
|
||||||
set({
|
|
||||||
edges: nextEdges,
|
|
||||||
});
|
|
||||||
},
|
|
||||||
deleteNodeById: (id: string) => {
|
deleteNodeById: (id: string) => {
|
||||||
const {
|
const {
|
||||||
nodes,
|
nodes,
|
||||||
@ -511,7 +500,7 @@ const useGraphStore = create<RFState>()(
|
|||||||
const edge = edges.find((x) => x.target === id);
|
const edge = edges.find((x) => x.target === id);
|
||||||
return getNode(edge?.source);
|
return getNode(edge?.source);
|
||||||
},
|
},
|
||||||
deleteCategorizeCaseEdges: (source, sourceHandle) => {
|
deleteEdgesBySourceAndSourceHandle: (source, sourceHandle) => {
|
||||||
const { edges, setEdges } = get();
|
const { edges, setEdges } = get();
|
||||||
setEdges(
|
setEdges(
|
||||||
edges.filter(
|
edges.filter(
|
||||||
|
|||||||
@ -1,106 +0,0 @@
|
|||||||
import fs from 'fs';
|
|
||||||
import path from 'path';
|
|
||||||
import customer_service from '../../../../graph/test/dsl_examples/customer_service.json';
|
|
||||||
import headhunter_zh from '../../../../graph/test/dsl_examples/headhunter_zh.json';
|
|
||||||
import interpreter from '../../../../graph/test/dsl_examples/interpreter.json';
|
|
||||||
import retrievalRelevantRewriteAndGenerate from '../../../../graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json';
|
|
||||||
import { dsl } from './mock';
|
|
||||||
import { buildNodesAndEdgesFromDSLComponents } from './utils';
|
|
||||||
|
|
||||||
test('buildNodesAndEdgesFromDSLComponents', () => {
|
|
||||||
const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(dsl.components);
|
|
||||||
|
|
||||||
expect(nodes.length).toEqual(4);
|
|
||||||
expect(edges.length).toEqual(4);
|
|
||||||
|
|
||||||
expect(edges).toEqual(
|
|
||||||
expect.arrayContaining([
|
|
||||||
expect.objectContaining({
|
|
||||||
source: 'begin',
|
|
||||||
target: 'Answer:China',
|
|
||||||
}),
|
|
||||||
expect.objectContaining({
|
|
||||||
source: 'Answer:China',
|
|
||||||
target: 'Retrieval:China',
|
|
||||||
}),
|
|
||||||
expect.objectContaining({
|
|
||||||
source: 'Retrieval:China',
|
|
||||||
target: 'Generate:China',
|
|
||||||
}),
|
|
||||||
expect.objectContaining({
|
|
||||||
source: 'Generate:China',
|
|
||||||
target: 'Answer:China',
|
|
||||||
}),
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('build nodes and edges from headhunter_zh dsl', () => {
|
|
||||||
const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(
|
|
||||||
headhunter_zh.components,
|
|
||||||
);
|
|
||||||
console.info('node length', nodes.length);
|
|
||||||
console.info('edge length', edges.length);
|
|
||||||
try {
|
|
||||||
fs.writeFileSync(
|
|
||||||
path.join(__dirname, 'headhunter_zh.json'),
|
|
||||||
JSON.stringify({ edges, nodes }, null, 4),
|
|
||||||
);
|
|
||||||
console.log('JSON data is saved.');
|
|
||||||
} catch (error) {
|
|
||||||
console.warn(error);
|
|
||||||
}
|
|
||||||
expect(nodes.length).toEqual(12);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('build nodes and edges from customer_service dsl', () => {
|
|
||||||
const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(
|
|
||||||
customer_service.components,
|
|
||||||
);
|
|
||||||
console.info('node length', nodes.length);
|
|
||||||
console.info('edge length', edges.length);
|
|
||||||
try {
|
|
||||||
fs.writeFileSync(
|
|
||||||
path.join(__dirname, 'customer_service.json'),
|
|
||||||
JSON.stringify({ edges, nodes }, null, 4),
|
|
||||||
);
|
|
||||||
console.log('JSON data is saved.');
|
|
||||||
} catch (error) {
|
|
||||||
console.warn(error);
|
|
||||||
}
|
|
||||||
expect(nodes.length).toEqual(12);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('build nodes and edges from interpreter dsl', () => {
|
|
||||||
const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(
|
|
||||||
interpreter.components,
|
|
||||||
);
|
|
||||||
console.info('node length', nodes.length);
|
|
||||||
console.info('edge length', edges.length);
|
|
||||||
try {
|
|
||||||
fs.writeFileSync(
|
|
||||||
path.join(__dirname, 'interpreter.json'),
|
|
||||||
JSON.stringify({ edges, nodes }, null, 4),
|
|
||||||
);
|
|
||||||
console.log('JSON data is saved.');
|
|
||||||
} catch (error) {
|
|
||||||
console.warn(error);
|
|
||||||
}
|
|
||||||
expect(nodes.length).toEqual(12);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('build nodes and edges from chat bot dsl', () => {
|
|
||||||
const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(
|
|
||||||
retrievalRelevantRewriteAndGenerate.components,
|
|
||||||
);
|
|
||||||
try {
|
|
||||||
fs.writeFileSync(
|
|
||||||
path.join(__dirname, 'retrieval_relevant_rewrite_and_generate.json'),
|
|
||||||
JSON.stringify({ edges, nodes }, null, 4),
|
|
||||||
);
|
|
||||||
console.log('JSON data is saved.');
|
|
||||||
} catch (error) {
|
|
||||||
console.warn(error);
|
|
||||||
}
|
|
||||||
expect(nodes.length).toEqual(12);
|
|
||||||
});
|
|
||||||
@ -6,91 +6,28 @@ import {
|
|||||||
} from '@/interfaces/database/agent';
|
} from '@/interfaces/database/agent';
|
||||||
import { DSLComponents, RAGFlowNodeType } from '@/interfaces/database/flow';
|
import { DSLComponents, RAGFlowNodeType } from '@/interfaces/database/flow';
|
||||||
import { removeUselessFieldsFromValues } from '@/utils/form';
|
import { removeUselessFieldsFromValues } from '@/utils/form';
|
||||||
import { Edge, Node, Position, XYPosition } from '@xyflow/react';
|
import { Edge, Node, XYPosition } from '@xyflow/react';
|
||||||
import { FormInstance, FormListFieldData } from 'antd';
|
import { FormInstance, FormListFieldData } from 'antd';
|
||||||
import { humanId } from 'human-id';
|
import { humanId } from 'human-id';
|
||||||
import { curry, get, intersectionWith, isEqual, omit, sample } from 'lodash';
|
import { curry, get, intersectionWith, isEqual, omit, sample } from 'lodash';
|
||||||
import pipe from 'lodash/fp/pipe';
|
import pipe from 'lodash/fp/pipe';
|
||||||
import isObject from 'lodash/isObject';
|
import isObject from 'lodash/isObject';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import {
|
import {
|
||||||
CategorizeAnchorPointPositions,
|
CategorizeAnchorPointPositions,
|
||||||
NoDebugOperatorsList,
|
NoDebugOperatorsList,
|
||||||
NodeHandleId,
|
NodeHandleId,
|
||||||
NodeMap,
|
|
||||||
Operator,
|
Operator,
|
||||||
} from './constant';
|
} from './constant';
|
||||||
import { BeginQuery, IPosition } from './interface';
|
import { BeginQuery, IPosition } from './interface';
|
||||||
|
|
||||||
const buildEdges = (
|
function buildAgentExceptionGoto(edges: Edge[], nodeId: string) {
|
||||||
operatorIds: string[],
|
const exceptionEdges = edges.filter(
|
||||||
currentId: string,
|
(x) =>
|
||||||
allEdges: Edge[],
|
x.source === nodeId && x.sourceHandle === NodeHandleId.AgentException,
|
||||||
isUpstream = false,
|
);
|
||||||
componentName: string,
|
|
||||||
nodeParams: Record<string, unknown>,
|
|
||||||
) => {
|
|
||||||
operatorIds.forEach((cur) => {
|
|
||||||
const source = isUpstream ? cur : currentId;
|
|
||||||
const target = isUpstream ? currentId : cur;
|
|
||||||
if (!allEdges.some((e) => e.source === source && e.target === target)) {
|
|
||||||
const edge: Edge = {
|
|
||||||
id: uuidv4(),
|
|
||||||
label: '',
|
|
||||||
// type: 'step',
|
|
||||||
source: source,
|
|
||||||
target: target,
|
|
||||||
// markerEnd: {
|
|
||||||
// type: MarkerType.ArrowClosed,
|
|
||||||
// color: 'rgb(157 149 225)',
|
|
||||||
// width: 20,
|
|
||||||
// height: 20,
|
|
||||||
// },
|
|
||||||
};
|
|
||||||
if (componentName === Operator.Categorize && !isUpstream) {
|
|
||||||
const categoryDescription =
|
|
||||||
nodeParams.category_description as ICategorizeItemResult;
|
|
||||||
|
|
||||||
const name = Object.keys(categoryDescription).find(
|
return exceptionEdges.map((x) => x.target);
|
||||||
(x) => categoryDescription[x].to === target,
|
}
|
||||||
);
|
|
||||||
|
|
||||||
if (name) {
|
|
||||||
edge.sourceHandle = name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
allEdges.push(edge);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const buildNodesAndEdgesFromDSLComponents = (data: DSLComponents) => {
|
|
||||||
const nodes: Node[] = [];
|
|
||||||
let edges: Edge[] = [];
|
|
||||||
|
|
||||||
Object.entries(data).forEach(([key, value]) => {
|
|
||||||
const downstream = [...value.downstream];
|
|
||||||
const upstream = [...value.upstream];
|
|
||||||
const { component_name: componentName, params } = value.obj;
|
|
||||||
nodes.push({
|
|
||||||
id: key,
|
|
||||||
type: NodeMap[value.obj.component_name as Operator] || 'ragNode',
|
|
||||||
position: { x: 0, y: 0 },
|
|
||||||
data: {
|
|
||||||
label: componentName,
|
|
||||||
name: humanId(),
|
|
||||||
form: params,
|
|
||||||
},
|
|
||||||
sourcePosition: Position.Left,
|
|
||||||
targetPosition: Position.Right,
|
|
||||||
});
|
|
||||||
|
|
||||||
buildEdges(upstream, key, edges, true, componentName, params);
|
|
||||||
buildEdges(downstream, key, edges, false, componentName, params);
|
|
||||||
});
|
|
||||||
|
|
||||||
return { nodes, edges };
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildComponentDownstreamOrUpstream = (
|
const buildComponentDownstreamOrUpstream = (
|
||||||
edges: Edge[],
|
edges: Edge[],
|
||||||
@ -103,7 +40,9 @@ const buildComponentDownstreamOrUpstream = (
|
|||||||
const node = nodes.find((x) => x.id === nodeId);
|
const node = nodes.find((x) => x.id === nodeId);
|
||||||
let isNotUpstreamTool = true;
|
let isNotUpstreamTool = true;
|
||||||
let isNotUpstreamAgent = true;
|
let isNotUpstreamAgent = true;
|
||||||
|
let isNotExceptionGoto = true;
|
||||||
if (isBuildDownstream && node?.data.label === Operator.Agent) {
|
if (isBuildDownstream && node?.data.label === Operator.Agent) {
|
||||||
|
isNotExceptionGoto = y.sourceHandle !== NodeHandleId.AgentException;
|
||||||
// Exclude the tool operator downstream of the agent operator
|
// Exclude the tool operator downstream of the agent operator
|
||||||
isNotUpstreamTool = !y.target.startsWith(Operator.Tool);
|
isNotUpstreamTool = !y.target.startsWith(Operator.Tool);
|
||||||
// Exclude the agent operator downstream of the agent operator
|
// Exclude the agent operator downstream of the agent operator
|
||||||
@ -115,7 +54,8 @@ const buildComponentDownstreamOrUpstream = (
|
|||||||
return (
|
return (
|
||||||
y[isBuildDownstream ? 'source' : 'target'] === nodeId &&
|
y[isBuildDownstream ? 'source' : 'target'] === nodeId &&
|
||||||
isNotUpstreamTool &&
|
isNotUpstreamTool &&
|
||||||
isNotUpstreamAgent
|
isNotUpstreamAgent &&
|
||||||
|
isNotExceptionGoto
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
.map((y) => y[isBuildDownstream ? 'target' : 'source']);
|
.map((y) => y[isBuildDownstream ? 'target' : 'source']);
|
||||||
@ -234,7 +174,10 @@ export const buildDslComponentsByGraph = (
|
|||||||
switch (operatorName) {
|
switch (operatorName) {
|
||||||
case Operator.Agent: {
|
case Operator.Agent: {
|
||||||
const { params: formData } = buildAgentTools(edges, nodes, id);
|
const { params: formData } = buildAgentTools(edges, nodes, id);
|
||||||
params = formData;
|
params = {
|
||||||
|
...formData,
|
||||||
|
exception_goto: buildAgentExceptionGoto(edges, id),
|
||||||
|
};
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Operator.Categorize:
|
case Operator.Categorize:
|
||||||
@ -559,7 +502,7 @@ export const buildCategorizeObjectFromList = (list: Array<ICategorizeItem>) => {
|
|||||||
if (cur?.name) {
|
if (cur?.name) {
|
||||||
pre[cur.name] = {
|
pre[cur.name] = {
|
||||||
...omit(cur, 'name', 'examples'),
|
...omit(cur, 'name', 'examples'),
|
||||||
examples: convertToStringArray(cur.examples),
|
examples: convertToStringArray(cur.examples) as string[],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return pre;
|
return pre;
|
||||||
|
|||||||
Reference in New Issue
Block a user