Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-23 23:16:58 +08:00)
Refactor add_llm and add speech to text (#12089)
### What problem does this PR solve?

1. Refactor implementation of add_llm
2. Add speech-to-text model.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
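Beyond moving to `match`, the refactor changes the fallback behavior: the old chain ended in an `else: pass` with a TODO, while the new dispatch raises on unknown types. A minimal sketch of the new shape, using hypothetical stand-in names rather than RAGFlow's real model registries:

# Minimal sketch of the new dispatch shape in add_llm (stand-in names, not the
# real registries): request fields are hoisted into locals once, each known
# type gets a case, and unknown types now raise instead of falling through an
# `else: pass`.
def validate_model(llm: dict, factory: str) -> str:
    msg = ""
    model_type = llm["model_type"]
    model_api_key = llm["api_key"]
    model_base_url = llm.get("api_base", "")

    match model_type:
        case "speech2text":
            try:
                # the real code instantiates Seq2txtModel[factory](...) here
                _client = (factory, model_api_key, model_base_url)
            except Exception as e:
                msg += f"\nFail to access model({factory}). {e}"
        case _:
            raise RuntimeError(f"Unknown model type: {model_type}")
    return msg

# Unknown types are now an error rather than a silent no-op:
try:
    validate_model({"model_type": "bogus", "api_key": "k"}, "OpenAI")
except RuntimeError as e:
    print(e)  # -> Unknown model type: bogus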
@@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from rag.utils.base64_image import test_image
-from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel


 @manager.route("/factories", methods=["GET"])  # noqa: F821
@@ -208,70 +208,83 @@ async def add_llm():
     msg = ""
     mdl_nm = llm["llm_name"].split("___")[0]
     extra = {"provider": factory}
-    if llm["model_type"] == LLMType.EMBEDDING.value:
-        assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-        mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            arr, tc = mdl.encode(["Test if the api key is available"])
-            if len(arr[0]) == 0:
-                raise Exception("Fail")
-        except Exception as e:
-            msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.CHAT.value:
-        assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-        mdl = ChatModel[factory](
-            key=llm["api_key"],
-            model_name=mdl_nm,
-            base_url=llm["api_base"],
-            **extra,
-        )
-        try:
-            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.RERANK:
-        assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
-        try:
-            mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-            arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
-            if len(arr) == 0:
-                raise Exception("Not known.")
-        except KeyError:
-            msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
-        assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
-        mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            image_data = test_image
-            m, tc = mdl.describe(image_data)
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.TTS:
-        assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
-        mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            for resp in mdl.tts("Hello~ RAGFlower!"):
-                pass
-        except RuntimeError as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.OCR.value:
-        assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
-        try:
-            mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
-            ok, reason = mdl.check_available()
-            if not ok:
-                raise RuntimeError(reason or "Model not available")
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    else:
-        # TODO: check other type of models
-        pass
+    model_type = llm["model_type"]
+    model_api_key = llm["api_key"]
+    model_base_url = llm.get("api_base", "")
+    match model_type:
+        case LLMType.EMBEDDING.value:
+            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
+            mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                arr, tc = mdl.encode(["Test if the api key is available"])
+                if len(arr[0]) == 0:
+                    raise Exception("Fail")
+            except Exception as e:
+                msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
+        case LLMType.CHAT.value:
+            assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
+            mdl = ChatModel[factory](
+                key=model_api_key,
+                model_name=mdl_nm,
+                base_url=model_base_url,
+                **extra,
+            )
+            try:
+                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                                             {"temperature": 0.9})
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.RERANK.value:
+            assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
+            try:
+                mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+                if len(arr) == 0:
+                    raise Exception("Not known.")
+            except KeyError:
+                msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.IMAGE2TEXT.value:
+            assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
+            mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                image_data = test_image
+                m, tc = mdl.describe(image_data)
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.TTS.value:
+            assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
+            mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                for resp in mdl.tts("Hello~ RAGFlower!"):
+                    pass
+            except RuntimeError as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.OCR.value:
+            assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
+            try:
+                mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                ok, reason = mdl.check_available()
+                if not ok:
+                    raise RuntimeError(reason or "Model not available")
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.SPEECH2TEXT:
+            assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
+            try:
+                mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                # TODO: check the availability
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case _:
+            raise RuntimeError(f"Unknown model type: {model_type}")

     if msg:
         return get_data_error_result(message=msg)
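One subtlety in the new hunk: the speech-to-text arm matches `LLMType.SPEECH2TEXT` without `.value`, unlike the other arms. That still works if `LLMType` in common.constants mixes in `str`, because `match` value patterns compare with `==`, and a str-mixin enum member equals its string value. A minimal demonstration under that assumption (the one-member `LLMType` below is a stand-in, not the real enum):

from enum import Enum

class LLMType(str, Enum):  # str mixin: members compare equal to their values
    SPEECH2TEXT = "speech2text"

match "speech2text":  # plain string, as it would arrive in a request body
    case LLMType.SPEECH2TEXT:  # value pattern: LLMType.SPEECH2TEXT == "speech2text"
        print("matched without .value")
    case _:
        print("no match")

If the real `LLMType` were a plain `Enum` without the `str` mixin, this arm would never match a string and `.value` would be required, as in the other cases.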