mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add support for VolcEngine - the current version supports SDK2 (#885)
- The main idea is to assemble **ak**, **sk**, and **ep_id** into a
dictionary and store it in the database **api_key** field
- I am not very familiar with the front-end, so I modeled the UI on the
existing Ollama integration; some parts of it may be redundant.
### Configuration method
- model name
- Format requirements: {"VolcEngine model name":"endpoint_id"}
- For example: {"Skylark-pro-32K":"ep-xxxxxxxxx"}
- Volcano ACCESS_KEY
- Format requirements: the VOLC_ACCESSKEY of the VolcEngine account
that hosts the model
- Volcano SECRET_KEY
- Format requirements: the VOLC_SECRETKEY of the VolcEngine account
that hosts the model
### What problem does this PR solve?
Adds support for the VolcEngine model provider (SDK2). VolcEngine uses
AK/SK plus an endpoint ID for authentication, so **volc_ak**, **volc_sk**,
and **ep_id** are assembled into a JSON object and stored in the database
**api_key** field.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -96,16 +96,29 @@ def set_api_key():
|
||||
@validate_request("llm_factory", "llm_name", "model_type")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
factory = req["llm_factory"]
|
||||
# For VolcEngine, due to its special authentication method
|
||||
# Assemble volc_ak, volc_sk, endpoint_id into api_key
|
||||
if factory == "VolcEngine":
|
||||
temp = list(eval(req["llm_name"]).items())[0]
|
||||
llm_name = temp[0]
|
||||
endpoint_id = temp[1]
|
||||
api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
|
||||
f'"volc_sk": "{req.get("volc_sk", "")}", ' \
|
||||
f'"ep_id": "{endpoint_id}", ' + '}'
|
||||
else:
|
||||
llm_name = req["llm_name"]
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": req["llm_factory"],
|
||||
"llm_factory": factory,
|
||||
"model_type": req["model_type"],
|
||||
"llm_name": req["llm_name"],
|
||||
"llm_name": llm_name,
|
||||
"api_base": req.get("api_base", ""),
|
||||
"api_key": "xxxxxxxxxxxxxxx"
|
||||
"api_key": api_key
|
||||
}
|
||||
|
||||
factory = req["llm_factory"]
|
||||
msg = ""
|
||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
mdl = EmbeddingModel[factory](
|
||||
@ -118,7 +131,10 @@ def add_llm():
|
||||
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
||||
elif llm["model_type"] == LLMType.CHAT.value:
|
||||
mdl = ChatModel[factory](
|
||||
key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
|
||||
key=llm['api_key'] if factory == "VolcEngine" else None,
|
||||
model_name=llm["llm_name"],
|
||||
base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
"temperature": 0.9})
|
||||
@ -134,7 +150,6 @@ def add_llm():
|
||||
if msg:
|
||||
return get_data_error_result(retmsg=msg)
|
||||
|
||||
|
||||
if not TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
|
||||
TenantLLMService.save(**llm)
|
||||
|
||||
Reference in New Issue
Block a user