mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 23:16:58 +08:00
Refactor add_llm and add speech to text (#12089)
### What problem does this PR solve? 1. Refactor implementation of add_llm 2. Add speech to text model. ### Type of change - [x] Refactoring Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
|
||||
|
||||
|
||||
@manager.route("/factories", methods=["GET"]) # noqa: F821
|
||||
@ -208,33 +208,39 @@ async def add_llm():
|
||||
msg = ""
|
||||
mdl_nm = llm["llm_name"].split("___")[0]
|
||||
extra = {"provider": factory}
|
||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
model_type = llm["model_type"]
|
||||
model_api_key = llm["api_key"]
|
||||
model_base_url = llm.get("api_base", "")
|
||||
match model_type:
|
||||
case LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
||||
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.CHAT.value:
|
||||
case LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
key=llm["api_key"],
|
||||
key=model_api_key,
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"],
|
||||
base_url=model_base_url,
|
||||
**extra,
|
||||
)
|
||||
try:
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9})
|
||||
if not tc and m.find("**ERROR**:") >= 0:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.RERANK:
|
||||
|
||||
case LLMType.RERANK.value:
|
||||
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
||||
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
|
||||
if len(arr) == 0:
|
||||
raise Exception("Not known.")
|
||||
@ -242,9 +248,10 @@ async def add_llm():
|
||||
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
|
||||
|
||||
case LLMType.IMAGE2TEXT.value:
|
||||
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
|
||||
mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
||||
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
try:
|
||||
image_data = test_image
|
||||
m, tc = mdl.describe(image_data)
|
||||
@ -252,26 +259,32 @@ async def add_llm():
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.TTS:
|
||||
case LLMType.TTS.value:
|
||||
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
||||
mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
||||
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
try:
|
||||
for resp in mdl.tts("Hello~ RAGFlower!"):
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.OCR.value:
|
||||
case LLMType.OCR.value:
|
||||
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
|
||||
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
ok, reason = mdl.check_available()
|
||||
if not ok:
|
||||
raise RuntimeError(reason or "Model not available")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
case LLMType.SPEECH2TEXT:
|
||||
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||
# TODO: check the availability
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
case _:
|
||||
raise RuntimeError(f"Unknown model type: {model_type}")
|
||||
|
||||
if msg:
|
||||
return get_data_error_result(message=msg)
|
||||
|
||||
Reference in New Issue
Block a user