Refactor add_llm and add speech-to-text (#12089)

### What problem does this PR solve?

1. Refactor implementation of add_llm
2. Add a speech-to-text model.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-12-22 19:27:26 +08:00
committed by GitHub
parent 4cbc91f2fa
commit e5f3d5ae26

View File

@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
from common.constants import StatusEnum, LLMType from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from rag.utils.base64_image import test_image from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
@manager.route("/factories", methods=["GET"]) # noqa: F821 @manager.route("/factories", methods=["GET"]) # noqa: F821
@ -208,33 +208,39 @@ async def add_llm():
msg = "" msg = ""
mdl_nm = llm["llm_name"].split("___")[0] mdl_nm = llm["llm_name"].split("___")[0]
extra = {"provider": factory} extra = {"provider": factory}
if llm["model_type"] == LLMType.EMBEDDING.value: model_type = llm["model_type"]
model_api_key = llm["api_key"]
model_base_url = llm.get("api_base", "")
match model_type:
case LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try: try:
arr, tc = mdl.encode(["Test if the api key is available"]) arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0: if len(arr[0]) == 0:
raise Exception("Fail") raise Exception("Fail")
except Exception as e: except Exception as e:
msg += f"\nFail to access embedding model({mdl_nm})." + str(e) msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value: case LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet." assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory]( mdl = ChatModel[factory](
key=llm["api_key"], key=model_api_key,
model_name=mdl_nm, model_name=mdl_nm,
base_url=llm["api_base"], base_url=model_base_url,
**extra, **extra,
) )
try: try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0: if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m) raise Exception(m)
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.RERANK:
case LLMType.RERANK.value:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try: try:
mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0: if len(arr) == 0:
raise Exception("Not known.") raise Exception("Not known.")
@ -242,9 +248,10 @@ async def add_llm():
msg += f"{factory} dose not support this model({factory}/{mdl_nm})" msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
case LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet." assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try: try:
image_data = test_image image_data = test_image
m, tc = mdl.describe(image_data) m, tc = mdl.describe(image_data)
@ -252,26 +259,32 @@ async def add_llm():
raise Exception(m) raise Exception(m)
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.TTS: case LLMType.TTS.value:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet." assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try: try:
for resp in mdl.tts("Hello~ RAGFlower!"): for resp in mdl.tts("Hello~ RAGFlower!"):
pass pass
except RuntimeError as e: except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.OCR.value: case LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet." assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try: try:
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", "")) mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
ok, reason = mdl.check_available() ok, reason = mdl.check_available()
if not ok: if not ok:
raise RuntimeError(reason or "Model not available") raise RuntimeError(reason or "Model not available")
except Exception as e: except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
else: case LLMType.SPEECH2TEXT:
# TODO: check other type of models assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
pass try:
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
# TODO: check the availability
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case _:
raise RuntimeError(f"Unknown model type: {model_type}")
if msg: if msg:
return get_data_error_result(message=msg) return get_data_error_result(message=msg)