add rerank model (#969)

### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-05-29 16:50:02 +08:00
committed by GitHub
parent e1f0644deb
commit 614defec21
17 changed files with 437 additions and 64 deletions

View File

@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel
@manager.route('/factories', methods=['GET'])
@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
except Exception as e:
return server_error_response(e)
@ -64,6 +64,16 @@ def set_api_key():
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
elif llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr[0]) == 0 or tc == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
if msg:
return get_data_error_result(retmsg=msg)
@ -199,7 +209,7 @@ def list_app():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
llm_set = set([m["llm_name"] for m in llms])
for o in objs: