Integration with Infinity (#2894)
### What problem does this PR solve?

Integration with Infinity:

- Replaced ELASTICSEARCH with dataStoreConn
- Renamed deleteByQuery to delete
- Renamed bulk to upsertBulk
- Added getHighlight and getAggregation
- Fixed KGSearch.search
- Moved Dealer.sql_retrieval to es_conn.py

### Type of change

- [x] Refactoring
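All store access now goes through the `DocStoreConnection` abstraction imported from `rag.utils.doc_store_conn`, so the same `Dealer` logic can run against Elasticsearch or Infinity. Below is a minimal sketch of that interface as exercised by this diff — method names and argument order are taken from the call sites; the type hints and the `delete`/`upsertBulk` signatures are assumptions, since those two are only named in the description:

```python
# Sketch of the DocStoreConnection surface as exercised by this diff.
# Inferred from call sites only; the real ABC in rag/utils/doc_store_conn.py
# may declare more methods and different type hints.
from abc import ABC, abstractmethod


class DocStoreConnection(ABC):
    @abstractmethod
    def search(self, selectFields: list[str], highlightFields: list[str],
               condition: dict, matchExprs: list, orderBy, offset: int,
               limit: int, indexNames, knowledgebaseIds: list[str]):
        """Run a query; matchExprs mixes full-text, dense, and fusion
        expressions. indexNames is a list in search() but a single string
        in chunk_list(), so both are apparently accepted."""

    @abstractmethod
    def getTotal(self, res) -> int: ...

    @abstractmethod
    def getChunkIds(self, res) -> list[str]: ...

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]: ...

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str) -> dict: ...

    @abstractmethod
    def getAggregation(self, res, fieldnm: str) -> list: ...

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """SQL passthrough; per the PR description the ES-specific MATCH()
        rewriting now lives in es_conn.py."""

    # Renames called out in the PR description; signatures assumed.
    @abstractmethod
    def delete(self, condition: dict, indexName: str) -> int:
        ...  # was: deleteByQuery

    @abstractmethod
    def upsertBulk(self, documents: list[dict], indexName: str):
        ...  # was: bulk
```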
```diff
@@ -14,34 +14,25 @@
 # limitations under the License.
 #
 
-import json
 import re
-from copy import deepcopy
-
-from elasticsearch_dsl import Q, Search
+import json
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
-from rag.settings import es_logger
+from rag.settings import doc_store_logger
 from rag.utils import rmSpace
-from rag.nlp import rag_tokenizer, query, is_english
+from rag.nlp import rag_tokenizer, query
 import numpy as np
+from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
 
 
 def index_name(uid): return f"ragflow_{uid}"
 
 
 class Dealer:
-    def __init__(self, es):
-        self.qryr = query.EsQueryer(es)
-        self.qryr.flds = [
-            "title_tks^10",
-            "title_sm_tks^5",
-            "important_kwd^30",
-            "important_tks^20",
-            "content_ltks^2",
-            "content_sm_ltks"]
-        self.es = es
+    def __init__(self, dataStore: DocStoreConnection):
+        self.qryr = query.FulltextQueryer()
+        self.dataStore = dataStore
 
     @dataclass
     class SearchResult:
```
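With the constructor change above, callers inject whichever backend implements `DocStoreConnection` instead of a raw ES client. A hedged wiring example — the concrete connection classes and the `DOC_ENGINE` switch are assumptions for illustration, not part of this diff:

```python
# Hypothetical wiring; class/module names are assumed for illustration.
import os
from rag.nlp.search import Dealer

if os.environ.get("DOC_ENGINE") == "infinity":
    from rag.utils.infinity_conn import InfinityConnection
    dealer = Dealer(InfinityConnection())
else:
    from rag.utils.es_conn import ESConnection
    dealer = Dealer(ESConnection())  # previously: Dealer(ELASTICSEARCH)
```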
```diff
@@ -54,170 +45,99 @@ class Dealer:
         keywords: Optional[List[str]] = None
         group_docs: List[List] = None
 
-    def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
-        qv, c = emb_mdl.encode_queries(txt)
-        return {
-            "field": "q_%d_vec" % len(qv),
-            "k": topk,
-            "similarity": sim,
-            "num_candidates": topk * 2,
-            "query_vector": [float(v) for v in qv]
-        }
+    def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
+        qv, _ = emb_mdl.encode_queries(txt)
+        embedding_data = [float(v) for v in qv]
+        vector_column_name = f"q_{len(embedding_data)}_vec"
+        return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
 
-    def _add_filters(self, bqry, req):
-        if req.get("kb_ids"):
-            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
-        if req.get("doc_ids"):
-            bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
-        if req.get("knowledge_graph_kwd"):
-            bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
-        if "available_int" in req:
-            if req["available_int"] == 0:
-                bqry.filter.append(Q("range", available_int={"lt": 1}))
-            else:
-                bqry.filter.append(
-                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
-        return bqry
+    def get_filters(self, req):
+        condition = dict()
+        for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
+            if key in req and req[key] is not None:
+                condition[field] = req[key]
+        # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
+        for key in ["knowledge_graph_kwd"]:
+            if key in req and req[key] is not None:
+                condition[key] = req[key]
+        return condition
 
-    def search(self, req, idxnms, emb_mdl=None, highlight=False):
-        qst = req.get("question", "")
-        bqry, keywords = self.qryr.question(qst, min_match="30%")
-        bqry = self._add_filters(bqry, req)
-        bqry.boost = 0.05
+    def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
+        filters = self.get_filters(req)
+        orderBy = OrderByExpr()
 
-        s = Search()
         pg = int(req.get("page", 1)) - 1
         topk = int(req.get("topk", 1024))
         ps = int(req.get("size", topk))
+        offset, limit = pg * ps, (pg + 1) * ps
 
         src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
-                                 "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
-                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
-
-        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
-        s = s.highlight("content_ltks")
-        s = s.highlight("title_ltks")
-        if not qst:
-            if not req.get("sort"):
-                s = s.sort(
-                    #{"create_time": {"order": "desc", "unmapped_type": "date"}},
-                    {"create_timestamp_flt": {
-                        "order": "desc", "unmapped_type": "float"}}
-                )
-            else:
-                s = s.sort(
-                    {"page_num_int": {"order": "asc", "unmapped_type": "float",
-                                      "mode": "avg", "numeric_type": "double"}},
-                    {"top_int": {"order": "asc", "unmapped_type": "float",
-                                 "mode": "avg", "numeric_type": "double"}},
-                    #{"create_time": {"order": "desc", "unmapped_type": "date"}},
-                    {"create_timestamp_flt": {
-                        "order": "desc", "unmapped_type": "float"}}
-                )
-
-        if qst:
-            s = s.highlight_options(
-                fragment_size=120,
-                number_of_fragments=5,
-                boundary_scanner_locale="zh-CN",
-                boundary_scanner="SENTENCE",
-                boundary_chars=",./;:\\!(),。?:!……()——、"
-            )
-        s = s.to_dict()
-        q_vec = []
-        if req.get("vector"):
-            assert emb_mdl, "No embedding model selected"
-            s["knn"] = self._vector(
-                qst, emb_mdl, req.get(
-                    "similarity", 0.1), topk)
-            s["knn"]["filter"] = bqry.to_dict()
-            if not highlight and "highlight" in s:
-                del s["highlight"]
-            q_vec = s["knn"]["query_vector"]
-        es_logger.info("【Q】: {}".format(json.dumps(s)))
-        res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
-        es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
-        if self.es.getTotal(res) == 0 and "knn" in s:
-            bqry, _ = self.qryr.question(qst, min_match="10%")
-            if req.get("doc_ids"):
-                bqry = Q("bool", must=[])
-            bqry = self._add_filters(bqry, req)
-            s["query"] = bqry.to_dict()
-            s["knn"]["filter"] = bqry.to_dict()
-            s["knn"]["similarity"] = 0.17
-            res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
-            es_logger.info("【Q】: {}".format(json.dumps(s)))
-
+                                 "doc_id", "position_list", "knowledge_graph_kwd",
+                                 "available_int", "content_with_weight"])
         kwds = set([])
-        for k in keywords:
-            kwds.add(k)
-            for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
-                if len(kk) < 2:
-                    continue
-                if kk in kwds:
-                    continue
-                kwds.add(kk)
-
-        aggs = self.getAggregation(res, "docnm_kwd")
+
+        qst = req.get("question", "")
+        q_vec = []
+        if not qst:
+            if req.get("sort"):
+                orderBy.desc("create_timestamp_flt")
+            res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+            total = self.dataStore.getTotal(res)
+            doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
+        else:
+            highlightFields = ["content_ltks", "title_tks"] if highlight else []
+            matchText, keywords = self.qryr.question(qst, min_match=0.3)
+            if emb_mdl is None:
+                matchExprs = [matchText]
+                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
+                total = self.dataStore.getTotal(res)
+                doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
+            else:
+                matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
+                q_vec = matchDense.embedding_data
+                src.append(f"q_{len(q_vec)}_vec")
+
+                fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
+                matchExprs = [matchText, matchDense, fusionExpr]
+
+                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
+                total = self.dataStore.getTotal(res)
+                doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
+
+                # If result is empty, try again with lower min_match
+                if total == 0:
+                    matchText, _ = self.qryr.question(qst, min_match=0.1)
+                    if "doc_ids" in filters:
+                        del filters["doc_ids"]
+                    matchDense.extra_options["similarity"] = 0.17
+                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
+                    total = self.dataStore.getTotal(res)
+                    doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
+
+            for k in keywords:
+                kwds.add(k)
+                for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
+                    if len(kk) < 2:
+                        continue
+                    if kk in kwds:
+                        continue
+                    kwds.add(kk)
+
+        doc_store_logger.info(f"TOTAL: {total}")
+        ids = self.dataStore.getChunkIds(res)
+        keywords = list(kwds)
+        highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
+        aggs = self.dataStore.getAggregation(res, "docnm_kwd")
         return self.SearchResult(
-            total=self.es.getTotal(res),
-            ids=self.es.getDocIds(res),
+            total=total,
+            ids=ids,
             query_vector=q_vec,
             aggregation=aggs,
-            highlight=self.getHighlight(res, keywords, "content_with_weight"),
-            field=self.getFields(res, src),
-            keywords=list(kwds)
+            highlight=highlight,
+            field=self.dataStore.getFields(res, src),
+            keywords=keywords
         )
 
-    def getAggregation(self, res, g):
-        if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
-            return
-        bkts = res["aggregations"]["aggs_" + g]["buckets"]
-        return [(b["key"], b["doc_count"]) for b in bkts]
-
-    def getHighlight(self, res, keywords, fieldnm):
-        ans = {}
-        for d in res["hits"]["hits"]:
-            hlts = d.get("highlight")
-            if not hlts:
-                continue
-            txt = "...".join([a for a in list(hlts.items())[0][1]])
-            if not is_english(txt.split(" ")):
-                ans[d["_id"]] = txt
-                continue
-
-            txt = d["_source"][fieldnm]
-            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
-            txts = []
-            for t in re.split(r"[.?!;\n]", txt):
-                for w in keywords:
-                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
-                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
-                txts.append(t)
-            ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
-
-        return ans
-
-    def getFields(self, sres, flds):
-        res = {}
-        if not flds:
-            return {}
-        for d in self.es.getSource(sres):
-            m = {n: d.get(n) for n in flds if d.get(n) is not None}
-            for n, v in m.items():
-                if isinstance(v, type([])):
-                    m[n] = "\t".join([str(vv) if not isinstance(
-                        vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
-                    continue
-                if not isinstance(v, type("")):
-                    m[n] = str(m[n])
-                #if n.find("tks") > 0:
-                #    m[n] = rmSpace(m[n])
-
-            if m:
-                res[d["id"]] = m
-        return res
-
     @staticmethod
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
```
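The rewritten `search()` above swaps the hand-built ES DSL (bool query, `knn` block, 0.17-similarity retry) for backend-neutral match expressions; the `weighted_sum` fusion keeps the old 0.05/0.95 text-to-vector balance that `bqry.boost = 0.05` used to express. A condensed, hedged sketch of just the hybrid path, using the diff's own names and default values:

```python
# Condensed from the new Dealer.search() above; illustrative only.
from rag.utils.doc_store_conn import FusionExpr, OrderByExpr


def hybrid_search(dealer, qst, src, filters, idx_names, kb_ids, emb_mdl,
                  topk=1024, offset=0, limit=1024):
    # Full-text and dense expressions, fused 0.05/0.95 as in the diff.
    match_text, _ = dealer.qryr.question(qst, min_match=0.3)
    match_dense = dealer.get_vector(qst, emb_mdl, topk, 0.1)
    fusion = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
    res = dealer.dataStore.search(src, [], filters,
                                  [match_text, match_dense, fusion],
                                  OrderByExpr(), offset, limit,
                                  idx_names, kb_ids)
    if dealer.dataStore.getTotal(res) == 0:
        # Same fallback the old ES path used: loosen the text match to 10%
        # and raise the dense similarity floor to 0.17.
        match_text, _ = dealer.qryr.question(qst, min_match=0.1)
        match_dense.extra_options["similarity"] = 0.17
        res = dealer.dataStore.search(src, [], filters,
                                      [match_text, match_dense, fusion],
                                      OrderByExpr(), offset, limit,
                                      idx_names, kb_ids)
    return res
```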
```diff
@@ -260,7 +180,7 @@ class Dealer:
                 continue
             idx.append(i)
             pieces_.append(t)
-        es_logger.info("{} => {}".format(answer, pieces_))
+        doc_store_logger.info("{} => {}".format(answer, pieces_))
         if not pieces_:
             return answer, set([])
 
```
```diff
@@ -281,7 +201,7 @@ class Dealer:
                                               chunks_tks,
                                               tkweight, vtweight)
             mx = np.max(sim) * 0.99
-            es_logger.info("{} SIM: {}".format(pieces_[i], mx))
+            doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
             if mx < thr:
                 continue
             cites[idx[i]] = list(
```
```diff
@@ -309,9 +229,15 @@ class Dealer:
     def rerank(self, sres, query, tkweight=0.3,
                vtweight=0.7, cfield="content_ltks"):
         _, keywords = self.qryr.question(query)
-        ins_embd = [
-            Dealer.trans2floats(
-                sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
+        vector_size = len(sres.query_vector)
+        vector_column = f"q_{vector_size}_vec"
+        zero_vector = [0.0] * vector_size
+        ins_embd = []
+        for chunk_id in sres.ids:
+            vector = sres.field[chunk_id].get(vector_column, zero_vector)
+            if isinstance(vector, str):
+                vector = [float(v) for v in vector.split("\t")]
+            ins_embd.append(vector)
         if not ins_embd:
             return [], [], []
 
```
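The old comprehension assumed every stored vector was a tab-joined string (hence `trans2floats`); the new loop tolerates both that ES encoding and native lists, which is presumably what Infinity returns. The same normalization as a standalone helper, with the dual-format behavior inferred from the `isinstance` check above:

```python
def as_vector(value, dim: int) -> list[float]:
    """Normalize a stored embedding to a float list.

    Elasticsearch stores q_<dim>_vec as a tab-joined string; other
    backends are assumed to return a native list already.
    """
    if isinstance(value, str):
        return [float(v) for v in value.split("\t")]
    return value if value else [0.0] * dim
```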
```diff
@@ -377,7 +303,7 @@ class Dealer:
         if isinstance(tenant_ids, str):
             tenant_ids = tenant_ids.split(",")
 
-        sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
         ranks["total"] = sres.total
 
         if page <= RERANK_PAGE_LIMIT:
```
```diff
@@ -393,6 +319,8 @@ class Dealer:
         idx = list(range(len(sres.ids)))
 
+        dim = len(sres.query_vector)
+        vector_column = f"q_{dim}_vec"
+        zero_vector = [0.0] * dim
         for i in idx:
             if sim[i] < similarity_threshold:
                 break
```
```diff
@@ -401,34 +329,32 @@ class Dealer:
                     continue
                 break
             id = sres.ids[i]
-            dnm = sres.field[id]["docnm_kwd"]
-            did = sres.field[id]["doc_id"]
+            chunk = sres.field[id]
+            dnm = chunk["docnm_kwd"]
+            did = chunk["doc_id"]
+            position_list = chunk.get("position_list", "[]")
+            if not position_list:
+                position_list = "[]"
             d = {
                 "chunk_id": id,
-                "content_ltks": sres.field[id]["content_ltks"],
-                "content_with_weight": sres.field[id]["content_with_weight"],
-                "doc_id": sres.field[id]["doc_id"],
+                "content_ltks": chunk["content_ltks"],
+                "content_with_weight": chunk["content_with_weight"],
+                "doc_id": chunk["doc_id"],
                 "docnm_kwd": dnm,
-                "kb_id": sres.field[id]["kb_id"],
-                "important_kwd": sres.field[id].get("important_kwd", []),
-                "img_id": sres.field[id].get("img_id", ""),
+                "kb_id": chunk["kb_id"],
+                "important_kwd": chunk.get("important_kwd", []),
+                "image_id": chunk.get("img_id", ""),
                 "similarity": sim[i],
                 "vector_similarity": vsim[i],
                 "term_similarity": tsim[i],
-                "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
-                "positions": sres.field[id].get("position_int", "").split("\t")
+                "vector": chunk.get(vector_column, zero_vector),
+                "positions": json.loads(position_list)
             }
             if highlight:
                 if id in sres.highlight:
                     d["highlight"] = rmSpace(sres.highlight[id])
                 else:
                     d["highlight"] = d["content_with_weight"]
-            if len(d["positions"]) % 5 == 0:
-                poss = []
-                for i in range(0, len(d["positions"]), 5):
-                    poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
-                                 float(d["positions"][i + 3]), float(d["positions"][i + 4])])
-                d["positions"] = poss
             ranks["chunks"].append(d)
             if dnm not in ranks["doc_aggs"]:
                 ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
```
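Positions switch from the flat `position_int` encoding (tab-separated numbers regrouped five at a time by the deleted `% 5` block) to a `position_list` column holding the same quintuples as JSON. A small sketch of the two encodings side by side; the coordinate values are invented:

```python
import json

# Old ES field: flat "\t"-separated values, regrouped five at a time in Python.
position_int = "1\t100\t200\t300\t400"
flat = position_int.split("\t")
old_positions = [[float(x) for x in flat[i:i + 5]] for i in range(0, len(flat), 5)]

# New field: JSON that already carries the nested structure.
position_list = "[[1, 100, 200, 300, 400]]"
new_positions = json.loads(position_list)

assert old_positions == [[1.0, 100.0, 200.0, 300.0, 400.0]]
assert new_positions == [[1, 100, 200, 300, 400]]
```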
```diff
@@ -442,39 +368,11 @@ class Dealer:
         return ranks
 
     def sql_retrieval(self, sql, fetch_size=128, format="json"):
-        from api.settings import chat_logger
-        sql = re.sub(r"[ `]+", " ", sql)
-        sql = sql.replace("%", "")
-        es_logger.info(f"Get es sql: {sql}")
-        replaces = []
-        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
-            fld, v = r.group(1), r.group(3)
-            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
-                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
-            replaces.append(
-                ("{}{}'{}'".format(
-                    r.group(1),
-                    r.group(2),
-                    r.group(3)),
-                    match))
+        tbl = self.dataStore.sql(sql, fetch_size, format)
+        return tbl
 
-        for p, r in replaces:
-            sql = sql.replace(p, r, 1)
-        chat_logger.info(f"To es: {sql}")
-
-        try:
-            tbl = self.es.sql(sql, fetch_size, format)
-            return tbl
-        except Exception as e:
-            chat_logger.error(f"SQL failure: {sql} =>" + str(e))
-            return {"error": str(e)}
-
-    def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
-        s = Search()
-        s = s.query(Q("match", doc_id=doc_id))[0:max_count]
-        s = s.to_dict()
-        es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
-        res = []
-        for index, chunk in enumerate(es_res['hits']['hits']):
-            res.append({fld: chunk['_source'].get(fld) for fld in fields})
-        return res
+    def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
+        condition = {"doc_id": doc_id}
+        res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
+        dict_chunks = self.dataStore.getFields(res, fields)
+        return dict_chunks.values()
```
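Per the PR description, the ES-specific SQL rewriting (`*_tks` predicates into `MATCH()`) moves into `es_conn.py`, leaving `Dealer.sql_retrieval` as a thin passthrough, and `chunk_list` becomes a plain condition query. A hedged usage sketch of the new `chunk_list`; the IDs are invented placeholders:

```python
# Illustrative only; the IDs are placeholders.
chunks = dealer.chunk_list(
    doc_id="doc-0001",
    tenant_id="tenant-42",
    kb_ids=["kb-a"],
    max_count=64,
)
for chunk in chunks:  # dict values: one dict of the requested fields per chunk
    print(chunk["docnm_kwd"], len(chunk["content_with_weight"]))
```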