From e5f3d5ae263b5f6ddaa341a8025b100ec3475943 Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Mon, 22 Dec 2025 19:27:26 +0800
Subject: [PATCH] Refactor add_llm and add speech to text (#12089)

### What problem does this PR solve?

1. Refactor the implementation of add_llm.
2. Add speech-to-text model support (see the example request sketched below).
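
For reference, a minimal sketch of registering a speech-to-text model through this endpoint. The URL path and the `llm_factory` field are assumptions (neither appears in this diff); `llm_name`, `model_type`, `api_key`, and `api_base` are the fields the handler reads.

```python
# Hypothetical request; the endpoint path and the "llm_factory" field are assumed, not shown in this diff.
import requests

payload = {
    "llm_factory": "OpenAI",      # provider name (assumed field name)
    "llm_name": "whisper-1",      # read as llm["llm_name"]
    "model_type": "speech2text",  # expected to route to the new LLMType.SPEECH2TEXT branch
    "api_key": "sk-...",          # read as llm["api_key"]
    "api_base": "",               # optional; read via llm.get("api_base", "")
}
resp = requests.post("http://localhost:9380/v1/llm/add_llm", json=payload)
print(resp.json())
```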

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai
---
 api/apps/llm_app.py | 143 ++++++++++++++++++++++++--------------------
 1 file changed, 78 insertions(+), 65 deletions(-)

diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index e77b90506..9a68e8256 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from rag.utils.base64_image import test_image
-from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
 
 
 @manager.route("/factories", methods=["GET"])  # noqa: F821
@@ -208,70 +208,83 @@ async def add_llm():
     msg = ""
     mdl_nm = llm["llm_name"].split("___")[0]
     extra = {"provider": factory}
-    if llm["model_type"] == LLMType.EMBEDDING.value:
-        assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-        mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            arr, tc = mdl.encode(["Test if the api key is available"])
-            if len(arr[0]) == 0:
-                raise Exception("Fail")
-        except Exception as e:
-            msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.CHAT.value:
-        assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-        mdl = ChatModel[factory](
-            key=llm["api_key"],
-            model_name=mdl_nm,
-            base_url=llm["api_base"],
-            **extra,
-        )
-        try:
-            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.RERANK:
-        assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
-        try:
-            mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-            arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
-            if len(arr) == 0:
-                raise Exception("Not known.")
-        except KeyError:
-            msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
-        assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
-        mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            image_data = test_image
-            m, tc = mdl.describe(image_data)
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.TTS:
-        assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
-        mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            for resp in mdl.tts("Hello~ RAGFlower!"):
-                pass
-        except RuntimeError as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.OCR.value:
-        assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
-        try:
-            mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
-            ok, reason = mdl.check_available()
-            if not ok:
-                raise RuntimeError(reason or "Model not available")
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    else:
-        # TODO: check other type of models
-        pass
+    model_type = llm["model_type"]
+    model_api_key = llm["api_key"]
+    model_base_url = llm.get("api_base", "")
+    match model_type:
+        case LLMType.EMBEDDING.value:
+            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
+            mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                arr, tc = mdl.encode(["Test if the api key is available"])
+                if len(arr[0]) == 0:
+                    raise Exception("Fail")
+            except Exception as e:
+                msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
+        case LLMType.CHAT.value:
+            assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
+            mdl = ChatModel[factory](
+                key=model_api_key,
+                model_name=mdl_nm,
+                base_url=model_base_url,
+                **extra,
+            )
+            try:
+                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                                             {"temperature": 0.9})
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.RERANK.value:
+            assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
+            try:
+                mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+                if len(arr) == 0:
+                    raise Exception("Not known.")
+            except KeyError:
+                msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.IMAGE2TEXT.value:
+            assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
+            mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                image_data = test_image
+                m, tc = mdl.describe(image_data)
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.TTS.value:
+            assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
+            mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                for resp in mdl.tts("Hello~ RAGFlower!"):
+                    pass
+            except RuntimeError as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.OCR.value:
+            assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
+            try:
+                mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                ok, reason = mdl.check_available()
+                if not ok:
+                    raise RuntimeError(reason or "Model not available")
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.SPEECH2TEXT:
+            assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
+            try:
+                mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                # TODO: check the availability
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case _:
+            raise RuntimeError(f"Unknown model type: {model_type}")
     if msg:
         return get_data_error_result(message=msg)
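
One way the `# TODO: check the availability` in the new SPEECH2TEXT branch could be filled in later, mirroring the probe pattern the other branches use. This is only a sketch: it assumes the Seq2txtModel wrappers expose a `transcription(audio)` method and accept a short silent WAV as input, neither of which is confirmed by this diff.

```python
# Sketch only: a possible availability probe for speech-to-text models.
# Assumes Seq2txtModel wrappers expose transcription(audio); adjust to the actual interface.
import io
import wave


def _silent_wav(duration_s: float = 0.2, rate: int = 16000) -> bytes:
    """Build a tiny in-memory silent WAV to use as a harmless probe input."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # 16-bit PCM
        w.setframerate(rate)
        w.writeframes(b"\x00\x00" * int(rate * duration_s))
    return buf.getvalue()


def check_speech2text_available(mdl) -> tuple[bool, str]:
    """Return (ok, reason) for a freshly constructed speech-to-text model instance."""
    try:
        text = mdl.transcription(_silent_wav())  # assumed method name
        if isinstance(text, str) and "**ERROR**" in text:
            return False, text
        return True, ""
    except Exception as e:
        return False, str(e)
```

If adopted, the SPEECH2TEXT case could report failures through `msg`, the same way the OCR branch surfaces `check_available()` failures.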