Add support for VolcEngine - the current version supports SDK2 (#885)

- The main idea is to assemble **ak**, **sk**, and **ep_id** into a
dictionary and store it in the database **api_key** field (see the
example under "Configuration method" below).
- I am not very familiar with the front-end, so I modeled it on the
existing Ollama integration; some of it may be redundant.

### Configuration method

- Model name
    - Format: {"VolcEngine model name":"endpoint_id"}
    - Example: {"Skylark2-pro-32k":"ep-xxxxxxxxx"}
- Volcano ACCESS_KEY
    - Format: the VOLC_ACCESSKEY of the VolcEngine account that hosts
the model
- Volcano SECRET_KEY
    - Format: the VOLC_SECRETKEY of the VolcEngine account that hosts
the model
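Taken together, the three settings above end up stored as a single JSON string in the **api_key** column. A minimal sketch of what that stored value looks like (all values are placeholders):

```python
# What the backend stores in the api_key column for a VolcEngine model:
# the access key, secret key and endpoint ID packed into one JSON string.
api_key = (
    '{"volc_ak": "<VOLC_ACCESSKEY>", '
    '"volc_sk": "<VOLC_SECRETKEY>", '
    '"ep_id": "ep-xxxxxxxxx"}'
)
```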
    
### What problem does this PR solve?

RAGFlow previously had no way to use models hosted on VolcEngine. VolcEngine
does not authenticate with a single API key; it needs an access key (**ak**),
a secret key (**sk**), and a per-model endpoint ID (**ep_id**). This PR adds
VolcEngine (SDK V2) as an LLM factory and packs those three values into a
JSON string stored in the existing **api_key** field, avoiding a schema change.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit eb51ad73d6 by yungongzi, 2024-05-23 11:15:29 +08:00 (parent
fbd0d74053), committed via GitHub. 10 changed files with 315 additions
and 8 deletions.

Changes to the **add_llm** endpoint:

```diff
@@ -96,16 +96,29 @@ def set_api_key():
 @validate_request("llm_factory", "llm_name", "model_type")
 def add_llm():
     req = request.json
+    factory = req["llm_factory"]
+    # For VolcEngine, due to its special authentication method,
+    # assemble volc_ak, volc_sk, endpoint_id into api_key
+    if factory == "VolcEngine":
+        temp = list(eval(req["llm_name"]).items())[0]
+        llm_name = temp[0]
+        endpoint_id = temp[1]
+        api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
+                        f'"volc_sk": "{req.get("volc_sk", "")}", ' \
+                        f'"ep_id": "{endpoint_id}"' + '}'
+    else:
+        llm_name = req["llm_name"]
+        api_key = "xxxxxxxxxxxxxxx"
     llm = {
         "tenant_id": current_user.id,
-        "llm_factory": req["llm_factory"],
+        "llm_factory": factory,
         "model_type": req["model_type"],
-        "llm_name": req["llm_name"],
+        "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
-        "api_key": "xxxxxxxxxxxxxxx"
+        "api_key": api_key
     }
-    factory = req["llm_factory"]
     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
@@ -118,7 +131,10 @@ def add_llm():
             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
     elif llm["model_type"] == LLMType.CHAT.value:
         mdl = ChatModel[factory](
-            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+            key=llm['api_key'] if factory == "VolcEngine" else None,
+            model_name=llm["llm_name"],
+            base_url=llm["api_base"]
+        )
         try:
             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                 "temperature": 0.9})
@@ -134,7 +150,6 @@ def add_llm():
     if msg:
         return get_data_error_result(retmsg=msg)
     if not TenantLLMService.filter_update(
             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
         TenantLLMService.save(**llm)
```
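For reference, a hedged sketch of the request body this endpoint would receive when registering a VolcEngine chat model. Field names are taken from the handler above; the keys and endpoint ID are placeholders, and `"chat"` assumes that is the value of `LLMType.CHAT.value`:

```python
# Example body for the add_llm request when the factory is VolcEngine.
# "llm_name" is the stringified {"model name": "endpoint_id"} dict that
# the handler unpacks; volc_ak/volc_sk are read with req.get(...).
payload = {
    "llm_factory": "VolcEngine",
    "model_type": "chat",  # assumed to match LLMType.CHAT.value
    "llm_name": '{"Skylark2-pro-32k": "ep-xxxxxxxxx"}',
    "api_base": "",
    "volc_ak": "<VOLC_ACCESSKEY>",
    "volc_sk": "<VOLC_SECRETKEY>",
}
```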

Registration of the VolcEngine factory and its models in **init_llm_factory**:

```diff
@@ -132,7 +132,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM",
     "status": "1",
-},
+},{
+    "name": "VolcEngine",
+    "logo": "",
+    "tags": "LLM, TEXT EMBEDDING",
+    "status": "1",
+}
 # {
 #     "name": "文心一言",
 #     "logo": "",
@@ -372,6 +377,21 @@ def init_llm_factory():
         "max_tokens": 16385,
         "model_type": LLMType.CHAT.value
     },
+    # ------------------------ VolcEngine -----------------------
+    {
+        "fid": factory_infos[9]["name"],
+        "llm_name": "Skylark2-pro-32k",
+        "tags": "LLM,CHAT,32k",
+        "max_tokens": 32768,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[9]["name"],
+        "llm_name": "Skylark2-pro-4k",
+        "tags": "LLM,CHAT,4k",
+        "max_tokens": 4096,
+        "model_type": LLMType.CHAT.value
+    },
 ]
 for info in factory_infos:
     try:
```
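The consumer side is not part of the excerpt above. A minimal sketch, assuming the VolcEngine chat wrapper decodes the composite key with `json.loads` (the helper name is hypothetical):

```python
import json

def unpack_volc_credentials(api_key: str):
    """Split the composite api_key stored by add_llm back into its parts.

    Hypothetical helper for illustration; the actual VolcEngine wrapper
    added by this PR is not shown in the diff above.
    """
    creds = json.loads(api_key)
    return creds["volc_ak"], creds["volc_sk"], creds["ep_id"]

# Usage: ak, sk, ep_id = unpack_volc_credentials(llm["api_key"])
```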