From 0878526ba86aba47b33b5e1596d3110dbe8218cd Mon Sep 17 00:00:00 2001 From: balibabu Date: Fri, 9 Jan 2026 13:42:28 +0800 Subject: [PATCH] Refactor: Refactoring OllamaModal using shadcn. #1036 (#12530) ### What problem does this PR solve? Refactor: Refactoring OllamaModal using shadcn. #1036 ### Type of change - [x] Refactoring --- .../user-setting/setting-model/hooks.tsx | 2 +- .../user-setting/setting-model/index.tsx | 20 +- .../modal/ollama-modal/index.tsx | 428 +++++++++--------- web/tsconfig.json | 4 +- 4 files changed, 232 insertions(+), 222 deletions(-) diff --git a/web/src/pages/user-setting/setting-model/hooks.tsx b/web/src/pages/user-setting/setting-model/hooks.tsx index ceb80f248..237999fef 100644 --- a/web/src/pages/user-setting/setting-model/hooks.tsx +++ b/web/src/pages/user-setting/setting-model/hooks.tsx @@ -117,7 +117,7 @@ export const useSubmitOllama = () => { const [selectedLlmFactory, setSelectedLlmFactory] = useState(''); const [editMode, setEditMode] = useState(false); const [initialValues, setInitialValues] = useState< - Partial | undefined + Partial & { provider_order?: string } >(); const { addLlm, loading } = useAddLlm(); const { diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index af7907bb0..1b549496c 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -229,15 +229,17 @@ const ModelProviders = () => { onOk={onApiKeySavingOk} llmFactory={llmFactory} > - + {llmAddingVisible && ( + + )} > = { [LLMFactory.Ollama]: 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/models/deploy_local_llm.mdx', [LLMFactory.Xinference]: @@ -43,7 +32,49 @@ const llmFactoryToUrlMap = { [LLMFactory.VLLM]: 'https://docs.vllm.ai/en/latest/', [LLMFactory.TokenPony]: 'https://docs.tokenpony.cn/#/', }; -type LlmFactory = keyof typeof llmFactoryToUrlMap; + +const optionsMap: Partial< + Record +> & { + Default: { label: string; value: string }[]; +} = { + [LLMFactory.HuggingFace]: [ + { label: 'embedding', value: 'embedding' }, + { label: 'chat', value: 'chat' }, + { label: 'rerank', value: 'rerank' }, + ], + [LLMFactory.LMStudio]: [ + { label: 'chat', value: 'chat' }, + { label: 'embedding', value: 'embedding' }, + { label: 'image2text', value: 'image2text' }, + ], + [LLMFactory.Xinference]: [ + { label: 'chat', value: 'chat' }, + { label: 'embedding', value: 'embedding' }, + { label: 'rerank', value: 'rerank' }, + { label: 'image2text', value: 'image2text' }, + { label: 'sequence2text', value: 'speech2text' }, + { label: 'tts', value: 'tts' }, + ], + [LLMFactory.ModelScope]: [{ label: 'chat', value: 'chat' }], + [LLMFactory.GPUStack]: [ + { label: 'chat', value: 'chat' }, + { label: 'embedding', value: 'embedding' }, + { label: 'rerank', value: 'rerank' }, + { label: 'sequence2text', value: 'speech2text' }, + { label: 'tts', value: 'tts' }, + ], + [LLMFactory.OpenRouter]: [ + { label: 'chat', value: 'chat' }, + { label: 'image2text', value: 'image2text' }, + ], + Default: [ + { label: 'chat', value: 'chat' }, + { label: 'embedding', value: 'embedding' }, + { label: 'rerank', value: 'rerank' }, + { label: 'image2text', value: 'image2text' }, + ], +}; const OllamaModal = ({ visible, @@ -53,215 +84,192 @@ const OllamaModal = ({ llmFactory, editMode = false, initialValues, -}: IModalProps & { +}: IModalProps & { provider_order?: string }> & { llmFactory: string; editMode?: boolean; - initialValues?: Partial; }) => { - const [form] = Form.useForm(); - const { t } = useTranslate('setting'); + const { t: tc } = useCommonTranslation(); + + const url = + llmFactoryToUrlMap[llmFactory as LLMFactory] || + 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/models/deploy_local_llm.mdx'; + + const fields = useMemo(() => { + const getOptions = (factory: string) => { + return optionsMap[factory as LLMFactory] || optionsMap.Default; + }; + + const baseFields: FormFieldConfig[] = [ + { + name: 'model_type', + label: t('modelType'), + type: FormFieldType.Select, + required: true, + options: getOptions(llmFactory), + validation: { + message: t('modelTypeMessage'), + }, + }, + { + name: 'llm_name', + label: t(llmFactory === 'Xinference' ? 'modelUid' : 'modelName'), + type: FormFieldType.Text, + required: true, + placeholder: t('modelNameMessage'), + validation: { + message: t('modelNameMessage'), + }, + }, + { + name: 'api_base', + label: t('addLlmBaseUrl'), + type: FormFieldType.Text, + required: true, + placeholder: t('baseUrlNameMessage'), + validation: { + message: t('baseUrlNameMessage'), + }, + }, + { + name: 'api_key', + label: t('apiKey'), + type: FormFieldType.Text, + required: false, + placeholder: t('apiKeyMessage'), + }, + { + name: 'max_tokens', + label: t('maxTokens'), + type: FormFieldType.Number, + required: true, + placeholder: t('maxTokensTip'), + validation: { + message: t('maxTokensMessage'), + }, + customValidate: (value: any) => { + if (value !== undefined && value !== null && value !== '') { + if (typeof value !== 'number') { + return t('maxTokensInvalidMessage'); + } + if (value < 0) { + return t('maxTokensMinMessage'); + } + } + return true; + }, + }, + ]; + + // Add provider_order field only for OpenRouter + if (llmFactory === 'OpenRouter') { + baseFields.push({ + name: 'provider_order', + label: 'Provider Order', + type: FormFieldType.Text, + required: false, + tooltip: 'Comma-separated provider list, e.g. Groq,Fireworks', + placeholder: 'Groq,Fireworks', + }); + } + + // Add vision switch (conditional on model_type === 'chat') + baseFields.push({ + name: 'vision', + label: t('vision'), + type: FormFieldType.Switch, + required: false, + dependencies: ['model_type'], + shouldRender: (formValues: any) => { + return formValues?.model_type === 'chat'; + }, + }); + + return baseFields; + }, [llmFactory, t]); + + const defaultValues: FieldValues = useMemo(() => { + if (editMode && initialValues) { + return { + llm_name: initialValues.llm_name || '', + model_type: initialValues.model_type || 'chat', + api_base: initialValues.api_base || '', + max_tokens: initialValues.max_tokens || 8192, + api_key: '', + vision: initialValues.model_type === 'image2text', + provider_order: initialValues.provider_order || '', + }; + } + return { + model_type: + llmFactory in optionsMap + ? optionsMap[llmFactory as LLMFactory]?.at(0)?.value + : 'embedding', + vision: false, + }; + }, [editMode, initialValues, llmFactory]); + + const handleOk = async (values?: FieldValues) => { + if (!values) return; - const handleOk = async () => { - const values = await form.validateFields(); const modelType = values.model_type === 'chat' && values.vision ? 'image2text' : values.model_type; - const data = { - ...omit(values, ['vision']), - model_type: modelType, + const data: IAddLlmRequestBody & { provider_order?: string } = { llm_factory: llmFactory, - max_tokens: values.max_tokens, + llm_name: values.llm_name as string, + model_type: modelType, + api_base: values.api_base as string, + api_key: values.api_key as string, + max_tokens: values.max_tokens as number, }; - console.info(data); - onOk?.(data); - }; - - const handleKeyDown = async (e: React.KeyboardEvent) => { - if (e.key === 'Enter') { - await handleOk(); + // Add provider_order only if it exists (for OpenRouter) + if (values.provider_order) { + data.provider_order = values.provider_order as string; } + + await onOk?.(data); }; - useEffect(() => { - if (visible && editMode && initialValues) { - const formValues = { - llm_name: initialValues.llm_name, - model_type: initialValues.model_type, - api_base: initialValues.api_base, - max_tokens: initialValues.max_tokens || 8192, - api_key: '', - ...initialValues, - }; - form.setFieldsValue(formValues); - } else if (visible && !editMode) { - form.resetFields(); - } - }, [visible, editMode, initialValues, form]); - - const url = - llmFactoryToUrlMap[llmFactory as LlmFactory] || - 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/models/deploy_local_llm.mdx'; - const optionsMap = { - [LLMFactory.HuggingFace]: [ - { value: 'embedding', label: 'embedding' }, - { value: 'chat', label: 'chat' }, - { value: 'rerank', label: 'rerank' }, - ], - [LLMFactory.LMStudio]: [ - { value: 'chat', label: 'chat' }, - { value: 'embedding', label: 'embedding' }, - { value: 'image2text', label: 'image2text' }, - ], - [LLMFactory.Xinference]: [ - { value: 'chat', label: 'chat' }, - { value: 'embedding', label: 'embedding' }, - { value: 'rerank', label: 'rerank' }, - { value: 'image2text', label: 'image2text' }, - { value: 'speech2text', label: 'sequence2text' }, - { value: 'tts', label: 'tts' }, - ], - [LLMFactory.ModelScope]: [{ value: 'chat', label: 'chat' }], - [LLMFactory.GPUStack]: [ - { value: 'chat', label: 'chat' }, - { value: 'embedding', label: 'embedding' }, - { value: 'rerank', label: 'rerank' }, - { value: 'speech2text', label: 'sequence2text' }, - { value: 'tts', label: 'tts' }, - ], - [LLMFactory.OpenRouter]: [ - { value: 'chat', label: 'chat' }, - { value: 'image2text', label: 'image2text' }, - ], - Default: [ - { value: 'chat', label: 'chat' }, - { value: 'embedding', label: 'embedding' }, - { value: 'rerank', label: 'rerank' }, - { value: 'image2text', label: 'image2text' }, - ], - }; - const getOptions = (factory: string) => { - return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default; - }; return ( } - open={visible} - onOk={handleOk} - onCancel={hideModal} - okButtonProps={{ loading }} - footer={(originNode: React.ReactNode) => { - return ( - - - {t('ollamaLink', { name: llmFactory })} - - {originNode} - - ); - }} + open={visible || false} + onOpenChange={(open) => !open && hideModal?.()} + maskClosable={false} + footer={<>} + footerClassName="py-1" > -
{}} + defaultValues={defaultValues} + labelClassName="font-normal" > - - label={t('modelType')} - name="model_type" - initialValue={'embedding'} - rules={[{ required: true, message: t('modelTypeMessage') }]} - > - - - - label={t(llmFactory === 'Xinference' ? 'modelUid' : 'modelName')} - name="llm_name" - rules={[{ required: true, message: t('modelNameMessage') }]} - > - - - - label={t('addLlmBaseUrl')} - name="api_base" - rules={[{ required: true, message: t('baseUrlNameMessage') }]} - > - - - - label={t('apiKey')} - name="api_key" - rules={[{ required: false, message: t('apiKeyMessage') }]} - > - - - - label={t('maxTokens')} - name="max_tokens" - rules={[ - { required: true, message: t('maxTokensMessage') }, - { - type: 'number', - message: t('maxTokensInvalidMessage'), - }, - ({}) => ({ - validator(_, value) { - if (value < 0) { - return Promise.reject(new Error(t('maxTokensMinMessage'))); - } - return Promise.resolve(); - }, - }), - ]} - > - - - {llmFactory === LLMFactory.OpenRouter && ( - - label="Provider Order" - name="provider_order" - tooltip="Comma-separated provider list, e.g. Groq,Fireworks" - rules={[]} - > - - - )} - - - {({ getFieldValue }) => - getFieldValue('model_type') === 'chat' && ( - - - - ) - } - - +
+ + {t('ollamaLink', { name: llmFactory })} + +
+ { + hideModal?.(); + }} + /> + { + handleOk(values); + }} + /> +
+
+
); }; diff --git a/web/tsconfig.json b/web/tsconfig.json index 513f1cca1..93c63b7e2 100644 --- a/web/tsconfig.json +++ b/web/tsconfig.json @@ -1,8 +1,8 @@ { "compilerOptions": { - "target": "ES2020", + "target": "ES2022", "useDefineForClassFields": true, - "lib": ["ES2020", "DOM", "DOM.Iterable"], + "lib": ["ES2022", "DOM", "DOM.Iterable"], "module": "ESNext", "skipLibCheck": true,