Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
support api-version and change default-model in adding azure-openai and openai (#2799)
### What problem does this PR solve?

#2701 #2712 #2749

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Committed by GitHub
parent bfaef2cca6
commit 18f80743eb
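For context on why the API version needs to be stored alongside the key: Azure's OpenAI-compatible endpoints are versioned, and the official `openai` Python SDK asks for that version when the client is constructed. The snippet below is only an illustration with placeholder values, not code from this PR:

```python
from openai import AzureOpenAI  # openai>=1.0

# Placeholder endpoint and key; api_version is the value this PR lets
# users supply when registering an Azure-OpenAI model.
client = AzureOpenAI(
    azure_endpoint="https://<resource>.openai.azure.com",
    api_key="<azure-api-key>",
    api_version="2024-02-01",
)
```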
@@ -58,7 +58,7 @@ def set_api_key():
     chat_passed, embd_passed, rerank_passed = False, False, False
     factory = req["llm_factory"]
     msg = ""
-    for llm in LLMService.query(fid=factory)[:3]:
+    for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             mdl = EmbeddingModel[factory](
                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -77,10 +77,10 @@ def set_api_key():
                                  {"temperature": 0.9,'max_tokens':50})
                 if m.find("**ERROR**") >=0:
                     raise Exception(m)
+                chat_passed = True
             except Exception as e:
                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                     e)
-            chat_passed = True
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             mdl = RerankModel[factory](
                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -88,10 +88,14 @@ def set_api_key():
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
                     raise Exception("Fail")
+                rerank_passed = True
+                print(f'passed model rerank{llm.llm_name}',flush=True)
             except Exception as e:
                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                     e)
-            rerank_passed = True
+        if any([embd_passed, chat_passed, rerank_passed]):
+            msg = ''
+            break
 
     if msg:
         return get_data_error_result(retmsg=msg)
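Taken together with the removal of the `[:3]` slice above, the key check in `set_api_key` now walks every model registered for the factory, marks a model type as passed only when a test call succeeds, and stops as soon as any type passes, discarding earlier per-model errors. A minimal sketch of that flow, using hypothetical stand-ins (`validate_api_key`, `try_model`) rather than the repository's classes:

```python
def validate_api_key(models, try_model):
    """models: iterable of (model_type, model_name) pairs; try_model: a
    callable that raises on failure. Both are hypothetical stand-ins for
    LLMService.query() and the Embedding/Chat/Rerank model wrappers."""
    passed = {"embedding": False, "chat": False, "rerank": False}
    msg = ""
    for model_type, name in models:
        if passed.get(model_type):
            continue  # this type already verified; skip further models of it
        try:
            try_model(model_type, name)
            passed[model_type] = True
        except Exception as e:
            msg += f"\nFail to access model({name}) using this api key. {e}"
        if any(passed.values()):
            msg = ""  # one working model is enough; drop earlier errors
            break
    return msg  # empty string means the key is accepted
```

The design choice here is that a single reachable model of any type is treated as proof the key works, so errors from models the key cannot access no longer block saving it.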
@@ -183,6 +187,10 @@ def add_llm():
         llm_name = req["llm_name"]
         api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
 
+    elif factory == "Azure-OpenAI":
+        llm_name = req["llm_name"]
+        api_key = apikey_json(["api_key", "api_version"])
+
     else:
         llm_name = req["llm_name"]
         api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
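The new branch stores the Azure key and API version together in the record's api_key field. Assuming `apikey_json` simply JSON-encodes the named request fields (its body is outside this hunk), the effect is roughly:

```python
import json

# Hypothetical request body for adding an Azure-OpenAI model; the values
# are placeholders, not defaults from the repository.
req = {
    "llm_factory": "Azure-OpenAI",
    "llm_name": "gpt-4o",
    "api_key": "<azure-api-key>",
    "api_version": "2024-02-01",
}

def apikey_json(keys):
    # assumed behaviour: pack the named request fields into one JSON string
    return json.dumps({k: req.get(k, "") for k in keys})

print(apikey_json(["api_key", "api_version"]))
# -> {"api_key": "<azure-api-key>", "api_version": "2024-02-01"}
```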