diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 2e878af9e..b95601706 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -96,16 +96,29 @@ def set_api_key():
@validate_request("llm_factory", "llm_name", "model_type")
def add_llm():
req = request.json
+ factory = req["llm_factory"]
+    # VolcEngine authenticates with an AK/SK pair rather than a single API key,
+    # so volc_ak, volc_sk and the endpoint id are packed together into the api_key field.
+ if factory == "VolcEngine":
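+        # llm_name arrives as '{"ModelName": "EndpointID"}' (see volcModelNameMessage in the web UI)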
+ temp = list(eval(req["llm_name"]).items())[0]
+ llm_name = temp[0]
+ endpoint_id = temp[1]
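+        # resulting api_key, e.g. '{"volc_ak": "...", "volc_sk": "...", "ep_id": "..."}'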
+        api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
+                        f'"volc_sk": "{req.get("volc_sk", "")}", ' \
+                        f'"ep_id": "{endpoint_id}"' + '}'
+ else:
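+        # every other factory keeps the placeholder api_key used before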
+ llm_name = req["llm_name"]
+ api_key = "xxxxxxxxxxxxxxx"
+
llm = {
"tenant_id": current_user.id,
- "llm_factory": req["llm_factory"],
+ "llm_factory": factory,
"model_type": req["model_type"],
- "llm_name": req["llm_name"],
+ "llm_name": llm_name,
"api_base": req.get("api_base", ""),
- "api_key": "xxxxxxxxxxxxxxx"
+ "api_key": api_key
}
- factory = req["llm_factory"]
msg = ""
if llm["model_type"] == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
@@ -118,7 +131,10 @@ def add_llm():
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
mdl = ChatModel[factory](
- key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
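+                # only VolcEngine needs the packed credential string; other factories are tried with key=None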
+ key=llm['api_key'] if factory == "VolcEngine" else None,
+ model_name=llm["llm_name"],
+ base_url=llm["api_base"]
+ )
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
"temperature": 0.9})
@@ -134,7 +150,6 @@ def add_llm():
if msg:
return get_data_error_result(retmsg=msg)
-
if not TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
TenantLLMService.save(**llm)
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 1bf6f01a6..3044ea976 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -132,7 +132,12 @@ factory_infos = [{
"logo": "",
"tags": "LLM",
"status": "1",
-},
+},{
+ "name": "VolcEngine",
+ "logo": "",
+ "tags": "LLM, TEXT EMBEDDING",
+ "status": "1",
+}
# {
# "name": "文心一言",
# "logo": "",
@@ -372,6 +377,21 @@ def init_llm_factory():
"max_tokens": 16385,
"model_type": LLMType.CHAT.value
},
+ # ------------------------ VolcEngine -----------------------
+ {
+ "fid": factory_infos[9]["name"],
+ "llm_name": "Skylark2-pro-32k",
+ "tags": "LLM,CHAT,32k",
+ "max_tokens": 32768,
+ "model_type": LLMType.CHAT.value
+ },
+ {
+ "fid": factory_infos[9]["name"],
+ "llm_name": "Skylark2-pro-4k",
+ "tags": "LLM,CHAT,4k",
+ "max_tokens": 4096,
+ "model_type": LLMType.CHAT.value
+ },
]
for info in factory_infos:
try:
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 9a2eec5d5..e9eb470c2 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -19,6 +19,7 @@ from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
+from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
@@ -315,3 +316,71 @@ class LocalLLM(Base):
yield answer + "\n**ERROR**: " + str(e)
yield token_count
+
+
+class VolcEngineChat(Base):
+ def __init__(self, key, model_name, base_url):
+        """
+        VolcEngine's authentication is unusual and we did not want to change the original
+        database fields, so volc_ak, volc_sk and ep_id are packed into api_key as a
+        dict-like string; it is parsed back into a dict here.
+        model_name is for display only; requests go to the endpoint identified by ep_id.
+        """
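+        # key looks like '{"volc_ak": "...", "volc_sk": "...", "ep_id": "..."}', as assembled in add_llm()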
+        auth = eval(key)
+        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+        self.volc_ak = auth.get('volc_ak', '')
+        self.volc_sk = auth.get('volc_sk', '')
+        self.client.set_ak(self.volc_ak)
+        self.client.set_sk(self.volc_sk)
+        self.model_name = auth.get('ep_id', '')
+
+ def chat(self, system, history, gen_conf):
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+ try:
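+            # map RAGFlow's gen_conf onto the VolcEngine MaaS chat request parameters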
+ req = {
+ "parameters": {
+ "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+ "top_k": gen_conf.get("top_k", 0),
+ "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+ "temperature": gen_conf.get("temperature", 0.1),
+ "max_new_tokens": gen_conf.get("max_tokens", 1000),
+ "top_p": gen_conf.get("top_p", 0.3),
+ },
+ "messages": history
+ }
+ response = self.client.chat(self.model_name, req)
+ ans = response.choices[0].message.content.strip()
+ if response.choices[0].finish_reason == "length":
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+ return ans, response.usage.total_tokens
+ except Exception as e:
+ return "**ERROR**: " + str(e), 0
+
+ def chat_streamly(self, system, history, gen_conf):
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+ ans = ""
+ try:
+ req = {
+ "parameters": {
+ "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+ "top_k": gen_conf.get("top_k", 0),
+ "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+ "temperature": gen_conf.get("temperature", 0.1),
+ "max_new_tokens": gen_conf.get("max_tokens", 1000),
+ "top_p": gen_conf.get("top_p", 0.3),
+ },
+ "messages": history
+ }
+ stream = self.client.stream_chat(self.model_name, req)
+ for resp in stream:
+ if not resp.choices[0].message.content:
+ continue
+ ans += resp.choices[0].message.content
+ yield ans
+                if resp.choices[0].finish_reason == "stop":
+                    yield resp.usage.total_tokens
+                    return
+
+ except Exception as e:
+ yield ans + "\n**ERROR**: " + str(e)
+ yield 0
diff --git a/web/src/assets/svg/llm/volc_engine.svg b/web/src/assets/svg/llm/volc_engine.svg
new file mode 100644
index 000000000..2c56cb00b
--- /dev/null
+++ b/web/src/assets/svg/llm/volc_engine.svg
@@ -0,0 +1,14 @@
+
\ No newline at end of file
diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index c5affca35..0e345ea63 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -477,6 +477,11 @@ The above is the content you need to summarize.`,
baseUrlNameMessage: 'Please input your base url!',
vision: 'Does it support Vision?',
ollamaLink: 'How to integrate {{name}}',
+ volcModelNameMessage: 'Please input your model name! Format: {"ModelName":"EndpointID"}',
+ addVolcEngineAK: 'VOLC ACCESS_KEY',
+ volcAKMessage: 'Please input your VOLC_ACCESS_KEY',
+ addVolcEngineSK: 'VOLC SECRET_KEY',
+    volcSKMessage: 'Please input your VOLC_SECRET_KEY',
},
message: {
registered: 'Registered!',
diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts
index 58123c0d0..cff078267 100644
--- a/web/src/locales/zh-traditional.ts
+++ b/web/src/locales/zh-traditional.ts
@@ -440,7 +440,12 @@ export default {
modelNameMessage: '請輸入模型名稱!',
modelTypeMessage: '請輸入模型類型!',
baseUrlNameMessage: '請輸入基礎 Url!',
- ollamaLink: '如何集成Ollama',
+ ollamaLink: '如何集成 {{name}}',
+ volcModelNameMessage: '請輸入模型名稱!格式:{"模型名稱":"EndpointID"}',
+ addVolcEngineAK: '火山 ACCESS_KEY',
+ volcAKMessage: '請輸入VOLC_ACCESS_KEY',
+ addVolcEngineSK: '火山 SECRET_KEY',
+ volcSKMessage: '請輸入VOLC_SECRET_KEY',
},
message: {
registered: '註冊成功',
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index 9442898f0..effb3bf20 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -458,6 +458,11 @@ export default {
modelTypeMessage: '请输入模型类型!',
baseUrlNameMessage: '请输入基础 Url!',
ollamaLink: '如何集成 {{name}}',
+ volcModelNameMessage: '请输入模型名称!格式:{"模型名称":"EndpointID"}',
+ addVolcEngineAK: '火山 ACCESS_KEY',
+ volcAKMessage: '请输入VOLC_ACCESS_KEY',
+ addVolcEngineSK: '火山 SECRET_KEY',
+ volcSKMessage: '请输入VOLC_SECRET_KEY',
},
message: {
registered: '注册成功',
diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts
index 6a5649225..dfdce7e25 100644
--- a/web/src/pages/user-setting/setting-model/hooks.ts
+++ b/web/src/pages/user-setting/setting-model/hooks.ts
@@ -166,6 +166,41 @@ export const useSubmitOllama = () => {
};
};
+export const useSubmitVolcEngine = () => {
+ const loading = useOneNamespaceEffectsLoading('settingModel', ['add_llm']);
+ const [selectedVolcFactory, setSelectedVolcFactory] = useState('');
+ const addLlm = useAddLlm();
+ const {
+ visible: volcAddingVisible,
+ hideModal: hideVolcAddingModal,
+ showModal: showVolcAddingModal,
+ } = useSetModalState();
+
+ const onVolcAddingOk = useCallback(
+ async (payload: IAddLlmRequestBody) => {
+ const ret = await addLlm(payload);
+ if (ret === 0) {
+ hideVolcAddingModal();
+ }
+ },
+ [hideVolcAddingModal, addLlm],
+ );
+
+ const handleShowVolcAddingModal = (llmFactory: string) => {
+ setSelectedVolcFactory(llmFactory);
+ showVolcAddingModal();
+ };
+
+ return {
+ volcAddingLoading: loading,
+ onVolcAddingOk,
+ volcAddingVisible,
+ hideVolcAddingModal,
+ showVolcAddingModal: handleShowVolcAddingModal,
+ selectedVolcFactory,
+ };
+};
+
export const useHandleDeleteLlm = (llmFactory: string) => {
const deleteLlm = useDeleteLlm();
const showDeleteConfirm = useShowDeleteConfirm();
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index 3d39d55d7..69a770108 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -37,10 +37,12 @@ import {
useSelectModelProvidersLoading,
useSubmitApiKey,
useSubmitOllama,
+ useSubmitVolcEngine,
useSubmitSystemModelSetting,
} from './hooks';
import styles from './index.less';
import OllamaModal from './ollama-modal';
+import VolcEngineModal from './volcengine-model';
import SystemModelSettingModal from './system-model-setting-modal';
const IconMap = {
@@ -52,6 +54,7 @@ const IconMap = {
Ollama: 'ollama',
Xinference: 'xinference',
DeepSeek: 'deepseek',
+ VolcEngine: 'volc_engine',
};
const LlmIcon = ({ name }: { name: string }) => {
@@ -165,6 +168,15 @@ const UserSettingModel = () => {
selectedLlmFactory,
} = useSubmitOllama();
+ const {
+ volcAddingVisible,
+ hideVolcAddingModal,
+ showVolcAddingModal,
+ onVolcAddingOk,
+ volcAddingLoading,
+ selectedVolcFactory,
+ } = useSubmitVolcEngine();
+
const handleApiKeyClick = useCallback(
(llmFactory: string) => {
if (isLocalLlmFactory(llmFactory)) {
@@ -179,6 +191,8 @@ const UserSettingModel = () => {
const handleAddModel = (llmFactory: string) => () => {
if (isLocalLlmFactory(llmFactory)) {
showLlmAddingModal(llmFactory);
+ } else if (llmFactory === 'VolcEngine') {
+ showVolcAddingModal('VolcEngine');
} else {
handleApiKeyClick(llmFactory);
}
@@ -270,6 +284,13 @@ const UserSettingModel = () => {
loading={llmAddingLoading}
llmFactory={selectedLlmFactory}
>
+      <VolcEngineModal
+        visible={volcAddingVisible}
+        hideModal={hideVolcAddingModal}
+        onOk={onVolcAddingOk}
+        loading={volcAddingLoading}
+        llmFactory={selectedVolcFactory}
+      ></VolcEngineModal>
);
};
diff --git a/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx b/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx
new file mode 100644
index 000000000..65872067a
--- /dev/null
+++ b/web/src/pages/user-setting/setting-model/volcengine-model/index.tsx
@@ -0,0 +1,118 @@
+import { useTranslate } from '@/hooks/commonHooks';
+import { IModalProps } from '@/interfaces/common';
+import { IAddLlmRequestBody } from '@/interfaces/request/llm';
+import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
+import omit from 'lodash/omit';
+
+type FieldType = IAddLlmRequestBody & { vision: boolean };
+
+const { Option } = Select;
+
+const VolcEngineModal = ({
+ visible,
+ hideModal,
+ onOk,
+ loading,
+ llmFactory
+}: IModalProps & { llmFactory: string }) => {
+ const [form] = Form.useForm();
+
+ const { t } = useTranslate('setting');
+
+ const handleOk = async () => {
+ const values = await form.validateFields();
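+    // a chat model with vision enabled is saved with the 'image2text' model type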
+ const modelType =
+ values.model_type === 'chat' && values.vision
+ ? 'image2text'
+ : values.model_type;
+
+ const data = {
+ ...omit(values, ['vision']),
+ model_type: modelType,
+ llm_factory: llmFactory,
+ };
+ console.info(data);
+
+ onOk?.(data);
+ };
+
+  return (
+    <Modal
+      open={visible}
+      onOk={handleOk}
+      onCancel={hideModal}
+      okButtonProps={{ loading }}
+      footer={(originNode: React.ReactNode) => {
+        return (
+          <Flex justify={'space-between'}>
+            <a>{t('ollamaLink', { name: llmFactory })}</a>
+            <Space>{originNode}</Space>
+          </Flex>
+        );
+      }}
+    >
+      <Form layout={'vertical'} form={form}>
+        <Form.Item
+          label={t('modelType')}
+          name="model_type"
+          initialValue={'chat'}
+          rules={[{ required: true, message: t('modelTypeMessage') }]}
+        >
+          <Select placeholder={t('modelTypeMessage')}>
+            <Option value="chat">chat</Option>
+            <Option value="embedding">embedding</Option>
+          </Select>
+        </Form.Item>
+        <Form.Item
+          label={t('modelName')}
+          name="llm_name"
+          rules={[{ required: true, message: t('volcModelNameMessage') }]}
+        >
+          <Input placeholder={t('volcModelNameMessage')} />
+        </Form.Item>
+        <Form.Item
+          label={t('addVolcEngineAK')}
+          name="volc_ak"
+          rules={[{ required: true, message: t('volcAKMessage') }]}
+        >
+          <Input placeholder={t('volcAKMessage')} />
+        </Form.Item>
+        <Form.Item
+          label={t('addVolcEngineSK')}
+          name="volc_sk"
+          rules={[{ required: true, message: t('volcSKMessage') }]}
+        >
+          <Input placeholder={t('volcSKMessage')} />
+        </Form.Item>
+        <Form.Item noStyle dependencies={['model_type']}>
+          {({ getFieldValue }) =>
+            getFieldValue('model_type') === 'chat' && (
+              <Form.Item
+                label={t('vision')}
+                valuePropName="checked"
+                name={'vision'}
+              >
+                <Switch />
+              </Form.Item>
+            )
+          }
+        </Form.Item>
+      </Form>
+    </Modal>
+ );
+};
+
+export default VolcEngineModal;