add rerank model (#969)

### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-05-29 16:50:02 +08:00
committed by GitHub
parent e1f0644deb
commit 614defec21
17 changed files with 437 additions and 64 deletions

View File

@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False)
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):
kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])

View File

@ -15,7 +15,7 @@
#
from api.db.services.user_service import TenantService
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
@ -73,21 +73,25 @@ class TenantLLMService(CommonService):
mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config: model_config = model_config.to_dict()
if not model_config:
if llm_type == LLMType.EMBEDDING.value:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config:
if llm_name == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
"llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
raise LookupError("Model({}) not authorized".format(mdlnm))
if llm_type == LLMType.EMBEDDING.value:
@ -96,6 +100,12 @@ class TenantLLMService(CommonService):
return EmbeddingModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
@ -125,14 +135,20 @@ class TenantLLMService(CommonService):
mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.llm_id if not llm_name else llm_name
else:
assert False, "LLM type error"
num = 0
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
.execute()
try:
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
.execute()
except Exception as e:
print(e)
pass
return num
@classmethod
@ -176,6 +192,14 @@ class LLMBundle(object):
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
return emd, used_tokens
def similarity(self, query: str, texts: list):
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
database_logger.error(
"Can't update token usage for {}/RERANK".format(self.tenant_id))
return sim, used_tokens
def describe(self, image, max_tokens=300):
txt, used_tokens = self.mdl.describe(image, max_tokens)
if not TenantLLMService.increase_usage(

View File

@ -93,6 +93,7 @@ class TenantService(CommonService):
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.rerank_id,
cls.model.asr_id,
cls.model.img2txt_id,
cls.model.parser_ids,