add rerank model (#969)

### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-05-29 16:50:02 +08:00
committed by GitHub
parent e1f0644deb
commit 614defec21
17 changed files with 437 additions and 64 deletions

View File

@ -54,7 +54,8 @@ class EsQueryer:
if not self.isChinese(txt):
tks = rag_tokenizer.tokenize(txt).split(" ")
tks_w = self.tw.weights(tks)
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
if not q:
@ -136,7 +137,11 @@ class EsQueryer:
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
return np.array(sims[0]) * vtweight + \
np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
d = {}
if isinstance(tks, str):
@ -149,9 +154,7 @@ class EsQueryer:
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
tksim = [self.similarity(atks, btks) for btks in btkss]
return np.array(sims[0]) * vtweight + \
np.array(tksim) * tkweight, tksim, sims[0]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):

View File

@ -241,11 +241,14 @@ class RagTokenizer:
return self.score_(res[::-1])
def english_normalize_(self, tks):
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
def tokenize(self, line):
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
zh_num = len([1 for c in line if is_chinese(c)])
if zh_num < len(line) * 0.2:
if zh_num == 0:
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
arr = re.split(self.SPLIT_CHAR, line)
@ -293,7 +296,7 @@ class RagTokenizer:
i = e + 1
res = " ".join(res)
res = " ".join(self.english_normalize_(res))
if self.DEBUG:
print("[TKS]", self.merge_(res))
return self.merge_(res)
@ -336,7 +339,7 @@ class RagTokenizer:
res.append(stk)
return " ".join(res)
return " ".join(self.english_normalize_(res))
def is_chinese(s):

View File

@ -71,8 +71,8 @@ class Dealer:
s = Search()
pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000))
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
@ -311,6 +311,26 @@ class Dealer:
ins_tw, tkweight, vtweight)
return sim, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
_, keywords = self.qryr.question(query)
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = sres.field[i][cfield].split(" ")
title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks + important_kwd
ins_tw.append(tks)
tksim = self.qryr.token_similarity(keywords, ins_tw)
vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
ins_embd,
@ -318,17 +338,22 @@ class Dealer:
rag_tokenizer.tokenize(inst).split(" "))
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold}
"similarity": similarity_threshold,
"available_int": 1}
sres = self.search(req, index_name(tenant_id), embd_mdl)
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
if rerank_mdl:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim * -1)
dim = len(sres.query_vector)