refine code (#595)

### What problem does this PR solve?

### Type of change

- [x] Refactoring
This commit is contained in:
KevinHuSh
2024-04-28 19:13:33 +08:00
committed by GitHub
parent aee8b48d2f
commit 8c07992b6c
24 changed files with 538 additions and 116 deletions

View File

@ -9,7 +9,7 @@ from dataclasses import dataclass
from rag.settings import es_logger
from rag.utils import rmSpace
from rag.nlp import huqie, query
from rag.nlp import rag_tokenizer, query
import numpy as np
@ -128,7 +128,7 @@ class Dealer:
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
@ -243,7 +243,7 @@ class Dealer:
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ")
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ")
for ck in chunks]
cites = {}
thr = 0.63
@ -251,7 +251,7 @@ class Dealer:
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
huqie.qie(
rag_tokenizer.tokenize(
self.qryr.rmWWW(pieces_[i])).split(" "),
chunks_tks,
tkweight, vtweight)
@ -310,8 +310,8 @@ class Dealer:
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
ins_embd,
huqie.qie(ans).split(" "),
huqie.qie(inst).split(" "))
rag_tokenizer.tokenize(ans).split(" "),
rag_tokenizer.tokenize(inst).split(" "))
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
@ -385,7 +385,7 @@ class Dealer:
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, huqie.qieqie(huqie.qie(v)))
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),