Feat: support verify to set llm key and boost bigrams. (#12980)
#12863

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
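The change applies one pattern throughout: every provider probe that previously ran unbounded is wrapped in `asyncio.wait_for` with a timeout read from the `LLM_TIMEOUT_SECONDS` environment variable (default 10 seconds). Blocking SDK calls (`encode`, `similarity`, `describe`, `check_available`, draining `tts`) are first moved off the event loop with `asyncio.to_thread`; calls that are already coroutines, such as `async_chat`, go to `wait_for` directly. A minimal standalone sketch of the pattern; `blocking_probe` is a hypothetical stand-in for a provider SDK call, not part of this diff:

```python
import asyncio
import os
import time


def blocking_probe() -> str:
    # Hypothetical stand-in for a synchronous SDK call such as
    # mdl.encode([...]) or mdl.similarity(...).
    time.sleep(1)
    return "ok"


async def verify_model() -> str:
    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
    # Run the blocking call on a worker thread so the event loop stays
    # responsive, and bound the whole round trip with wait_for.
    return await asyncio.wait_for(
        asyncio.to_thread(blocking_probe),
        timeout=timeout_seconds,
    )


print(asyncio.run(verify_model()))
```

If the timeout expires, `wait_for` raises `asyncio.TimeoutError`, which the handlers' existing `except Exception` blocks fold into `msg` like any other verification failure.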
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import json
 import os
@@ -64,13 +65,17 @@ async def set_api_key():
     chat_passed, embd_passed, rerank_passed = False, False, False
     factory = req["llm_factory"]
     extra = {"provider": factory}
+    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
     msg = ""
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
-                arr, tc = mdl.encode(["Test if the api key is available"])
+                arr, tc = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
+                    timeout=timeout_seconds,
+                )
                 if len(arr[0]) == 0:
                     raise Exception("Fail")
                 embd_passed = True
@@ -80,17 +85,27 @@ async def set_api_key():
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
             try:
-                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+                m, tc = await asyncio.wait_for(
+                    mdl.async_chat(
+                        None,
+                        [{"role": "user", "content": "Hello! How are you doing!"}],
+                        {"temperature": 0.9, "max_tokens": 50},
+                    ),
+                    timeout=timeout_seconds,
+                )
                 if m.find("**ERROR**") >= 0:
                     raise Exception(m)
                 chat_passed = True
             except Exception as e:
                 msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
-        elif not rerank_passed and llm.model_type == LLMType.RERANK:
+        elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
             mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
             try:
-                arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
+                arr, tc = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
+                    timeout=timeout_seconds,
+                )
                 if len(arr) == 0 or tc == 0:
                     raise Exception("Fail")
                 rerank_passed = True
@@ -101,6 +116,9 @@ async def set_api_key():
             msg = ""
             break
 
+    if req.get("verify", False):
+        return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
+
     if msg:
         return get_data_error_result(message=msg)
 
```
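With the new `verify` flag, `set_api_key` short-circuits at this point and reports whether the key worked instead of persisting it. A hedged client-side sketch; the host, route prefix, and auth header below are illustrative assumptions, not taken from this diff:

```python
import requests

# Hypothetical dry-run call; adjust host, path, and auth to your deployment.
resp = requests.post(
    "http://localhost:9380/v1/llm/set_api_key",
    headers={"Authorization": "Bearer <session-token>"},
    json={
        "llm_factory": "OpenAI",
        "api_key": "sk-...",
        "verify": True,  # report success/failure only; do not save the key
    },
)
# The data payload carries {"message": <accumulated errors>, "success": <bool>}.
print(resp.json())
```

`add_llm` gains the same flag below, so both endpoints double as pure connectivity checks.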
```diff
@@ -133,6 +151,7 @@ async def add_llm():
     factory = req["llm_factory"]
     api_key = req.get("api_key", "x")
     llm_name = req.get("llm_name")
+    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
 
     if factory not in [f.name for f in get_allowed_llm_factories()]:
         return get_data_error_result(message=f"LLM factory {factory} is not allowed")
@@ -215,7 +234,10 @@ async def add_llm():
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
             try:
-                arr, tc = mdl.encode(["Test if the api key is available"])
+                arr, tc = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
+                    timeout=timeout_seconds,
+                )
                 if len(arr[0]) == 0:
                     raise Exception("Fail")
             except Exception as e:
@@ -229,7 +251,14 @@ async def add_llm():
                 **extra,
             )
             try:
-                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+                m, tc = await asyncio.wait_for(
+                    mdl.async_chat(
+                        None,
+                        [{"role": "user", "content": "Hello! How are you doing!"}],
+                        {"temperature": 0.9},
+                    ),
+                    timeout=timeout_seconds,
+                )
                 if not tc and m.find("**ERROR**:") >= 0:
                     raise Exception(m)
             except Exception as e:
@@ -239,7 +268,10 @@ async def add_llm():
             assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
             try:
                 mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
-                arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+                arr, tc = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
+                    timeout=timeout_seconds,
+                )
                 if len(arr) == 0:
                     raise Exception("Not known.")
             except KeyError:
@@ -252,7 +284,10 @@ async def add_llm():
             mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
             try:
                 image_data = test_image
-                m, tc = mdl.describe(image_data)
+                m, tc = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.describe, image_data),
+                    timeout=timeout_seconds,
+                )
                 if not tc and m.find("**ERROR**:") >= 0:
                     raise Exception(m)
             except Exception as e:
@@ -261,20 +296,29 @@ async def add_llm():
             assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
             mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
             try:
-                for resp in mdl.tts("Hello~ RAGFlower!"):
-                    pass
+                def drain_tts():
+                    for _ in mdl.tts("Hello~ RAGFlower!"):
+                        pass
+
+                await asyncio.wait_for(
+                    asyncio.to_thread(drain_tts),
+                    timeout=timeout_seconds,
+                )
             except RuntimeError as e:
                 msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
         case LLMType.OCR.value:
             assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
             try:
                 mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
-                ok, reason = mdl.check_available()
+                ok, reason = await asyncio.wait_for(
+                    asyncio.to_thread(mdl.check_available),
+                    timeout=timeout_seconds,
+                )
                 if not ok:
                     raise RuntimeError(reason or "Model not available")
             except Exception as e:
                 msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-        case LLMType.SPEECH2TEXT:
+        case LLMType.SPEECH2TEXT.value:
             assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
             try:
                 mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
```
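The TTS branch above is the one spot where `asyncio.to_thread(mdl.tts, ...)` would not work: calling a generator function only creates the generator, so the worker thread would return immediately and the timeout would never cover the actual streaming. The `drain_tts` closure consumes the stream inside the thread instead. A self-contained sketch of the difference, with `slow_stream` as a hypothetical stand-in for `mdl.tts`:

```python
import asyncio
import time


def slow_stream():
    # Hypothetical stand-in for mdl.tts(...): a blocking generator.
    for chunk in range(3):
        time.sleep(1)
        yield chunk


async def main():
    # Without draining: to_thread() just builds the generator and returns
    # instantly, so even timeout=0.1 "passes" while no audio was fetched.
    gen = await asyncio.wait_for(asyncio.to_thread(slow_stream), timeout=0.1)
    print(type(gen))  # an unconsumed generator object

    # With draining: the whole stream is consumed in the worker thread,
    # so wait_for times the real work.
    def drain():
        for _ in slow_stream():
            pass

    await asyncio.wait_for(asyncio.to_thread(drain), timeout=5)
    print("drained within the timeout")


asyncio.run(main())
```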
```diff
@@ -284,6 +328,9 @@ async def add_llm():
         case _:
             raise RuntimeError(f"Unknown model type: {model_type}")
 
+    if req.get("verify", False):
+        return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
+
     if msg:
         return get_data_error_result(message=msg)
 
```