Refactor add_llm and add speech to text (#12089)

### What problem does this PR solve?

1. Refactor the implementation of `add_llm`: the per-type availability checks now dispatch on the model type with `match`/`case` instead of an `if`/`elif` chain.
2. Add speech-to-text model support (`Seq2txtModel`); see the sketch after this list.
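
For reference, a minimal sketch of the kind of request body the updated handler inspects when a speech-to-text model is registered. Only the keys actually read in the diff below (`llm_name`, `model_type`, `api_key`, `api_base`) come from this PR; the example model name and the `"speech2text"` string are assumptions for illustration.

```python
# Hypothetical request body for add_llm when registering a speech-to-text model.
# The four keys below are the ones the handler reads; the values are illustrative only.
llm = {
    "llm_name": "whisper-1",      # assumed example model name
    "model_type": "speech2text",  # assumed string value of LLMType.SPEECH2TEXT
    "api_key": "sk-...",          # credential probed during registration
    "api_base": "",               # optional; the handler falls back via llm.get("api_base", "")
}

# Mirrors the handler: strip any "___" suffix from the display name, then the
# match statement routes "speech2text" to the new Seq2txtModel branch.
mdl_nm = llm["llm_name"].split("___")[0]
assert mdl_nm == "whisper-1"
```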

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai
Committed: 2025-12-22 19:27:26 +08:00 (committed by GitHub)
Parent: 4cbc91f2fa
Commit: e5f3d5ae26

```diff
@@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from rag.utils.base64_image import test_image
-from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
 
 
 @manager.route("/factories", methods=["GET"])  # noqa: F821
@@ -208,70 +208,83 @@ async def add_llm():
     msg = ""
     mdl_nm = llm["llm_name"].split("___")[0]
     extra = {"provider": factory}
-    if llm["model_type"] == LLMType.EMBEDDING.value:
-        assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-        mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            arr, tc = mdl.encode(["Test if the api key is available"])
-            if len(arr[0]) == 0:
-                raise Exception("Fail")
-        except Exception as e:
-            msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.CHAT.value:
-        assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-        mdl = ChatModel[factory](
-            key=llm["api_key"],
-            model_name=mdl_nm,
-            base_url=llm["api_base"],
-            **extra,
-        )
-        try:
-            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.RERANK:
-        assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
-        try:
-            mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-            arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
-            if len(arr) == 0:
-                raise Exception("Not known.")
-        except KeyError:
-            msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
-        assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
-        mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            image_data = test_image
-            m, tc = mdl.describe(image_data)
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.TTS:
-        assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
-        mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            for resp in mdl.tts("Hello~ RAGFlower!"):
-                pass
-        except RuntimeError as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.OCR.value:
-        assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
-        try:
-            mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
-            ok, reason = mdl.check_available()
-            if not ok:
-                raise RuntimeError(reason or "Model not available")
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    else:
-        # TODO: check other type of models
-        pass
+    model_type = llm["model_type"]
+    model_api_key = llm["api_key"]
+    model_base_url = llm.get("api_base", "")
+    match model_type:
+        case LLMType.EMBEDDING.value:
+            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
+            mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                arr, tc = mdl.encode(["Test if the api key is available"])
+                if len(arr[0]) == 0:
+                    raise Exception("Fail")
+            except Exception as e:
+                msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
+        case LLMType.CHAT.value:
+            assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
+            mdl = ChatModel[factory](
+                key=model_api_key,
+                model_name=mdl_nm,
+                base_url=model_base_url,
+                **extra,
+            )
+            try:
+                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                                             {"temperature": 0.9})
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.RERANK.value:
+            assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
+            try:
+                mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+                if len(arr) == 0:
+                    raise Exception("Not known.")
+            except KeyError:
+                msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.IMAGE2TEXT.value:
+            assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
+            mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                image_data = test_image
+                m, tc = mdl.describe(image_data)
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.TTS.value:
+            assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
+            mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                for resp in mdl.tts("Hello~ RAGFlower!"):
+                    pass
+            except RuntimeError as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.OCR.value:
+            assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
+            try:
+                mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                ok, reason = mdl.check_available()
+                if not ok:
+                    raise RuntimeError(reason or "Model not available")
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.SPEECH2TEXT:
+            assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
+            try:
+                mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                # TODO: check the availability
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case _:
+            raise RuntimeError(f"Unknown model type: {model_type}")
     if msg:
         return get_data_error_result(message=msg)
```
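
A side note on the dispatch style used above: `match` value patterns such as `case LLMType.EMBEDDING.value:` compare the incoming string with `==` against the enum member's value. The toy example below (not RAGFlow code; the enum and its values are assumptions) shows the pattern in isolation; whether a bare member like `LLMType.SPEECH2TEXT` also matches a plain string depends on whether `LLMType` is a `str`-based enum.

```python
# Standalone toy illustrating the match/case dispatch used in add_llm.
# ToyLLMType is a stand-in with assumed values, not common.constants.LLMType.
from enum import Enum


class ToyLLMType(Enum):
    EMBEDDING = "embedding"
    SPEECH2TEXT = "speech2text"


def probe(model_type: str) -> str:
    match model_type:
        case ToyLLMType.EMBEDDING.value:    # value pattern: compares with == against "embedding"
            return "probe embedding model"
        case ToyLLMType.SPEECH2TEXT.value:  # matches the plain string "speech2text"
            return "probe speech-to-text model"
        case _:
            raise RuntimeError(f"Unknown model type: {model_type}")


assert probe("speech2text") == "probe speech-to-text model"
```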