mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add rerank model (#969)
### What problem does this PR solve? feat: add rerank models to the project #724 #162 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import EmbeddingModel, ChatModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET'])
|
||||
@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
||||
def factories():
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
|
||||
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -64,6 +64,16 @@ def set_api_key():
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
elif llm.model_type == LLMType.RERANK:
|
||||
mdl = RerankModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
try:
|
||||
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||
if len(arr[0]) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
|
||||
if msg:
|
||||
return get_data_error_result(retmsg=msg)
|
||||
@ -199,7 +209,7 @@ def list_app():
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
|
||||
|
||||
llm_set = set([m["llm_name"] for m in llms])
|
||||
for o in objs:
|
||||
|
||||
Reference in New Issue
Block a user