support api-version and change default-model in adding azure-openai and openai (#2799)

### What problem does this PR solve?
#2701 #2712 #2749

### Type of change
-[x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
JobSmithManipulation
2024-10-11 11:26:42 +08:00
committed by GitHub
parent bfaef2cca6
commit 18f80743eb
11 changed files with 203 additions and 10 deletions

View File

@ -58,7 +58,7 @@ def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
msg = ""
for llm in LLMService.query(fid=factory)[:3]:
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@ -77,10 +77,10 @@ def set_api_key():
{"temperature": 0.9,'max_tokens':50})
if m.find("**ERROR**") >=0:
raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
chat_passed = True
elif not rerank_passed and llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@ -88,10 +88,14 @@ def set_api_key():
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
print(f'passed model rerank{llm.llm_name}',flush=True)
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
rerank_passed = True
if any([embd_passed, chat_passed, rerank_passed]):
msg = ''
break
if msg:
return get_data_error_result(retmsg=msg)
@ -183,6 +187,10 @@ def add_llm():
llm_name = req["llm_name"]
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
elif factory == "Azure-OpenAI":
llm_name = req["llm_name"]
api_key = apikey_json(["api_key", "api_version"])
else:
llm_name = req["llm_name"]
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")