Add pagerank to KB. (#3809)

### What problem does this PR solve?

#3794

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2024-12-03 14:30:35 +08:00
committed by GitHub
parent 7543047de3
commit 74b28ef1b0
11 changed files with 67 additions and 26 deletions

View File

@ -75,7 +75,7 @@ class Dealer:
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"doc_id", "position_list", "knowledge_graph_kwd",
"available_int", "content_with_weight"])
"available_int", "content_with_weight", "pagerank_fea"])
kwds = set([])
qst = req.get("question", "")
@ -234,11 +234,13 @@ class Dealer:
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
pageranks = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [float(v) for v in vector.split("\t")]
ins_embd.append(vector)
pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
if not ins_embd:
return [], [], []
@ -257,7 +259,8 @@ class Dealer:
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim, tksim, vtsim
return sim+np.array(pageranks, dtype=float), tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
@ -351,7 +354,7 @@ class Dealer:
"vector": chunk.get(vector_column, zero_vector),
"positions": json.loads(position_list)
}
if highlight:
if highlight and sres.highlight:
if id in sres.highlight:
d["highlight"] = rmSpace(sres.highlight[id])
else:

View File

@ -201,6 +201,7 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
@ -431,7 +433,7 @@ def do_handle_task(task):
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
if doc_store_result:
error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
logging.error(error_message)

View File

@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection):
)
if bqry:
bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
s = s.query(bqry)
for field in highlightFields:
s = s.highlight(field)
@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection):
f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
if str(e).find("Timeout") > 0:
continue
return False
else:
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exist":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection):
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
for k, v in newValue.items():
if k == "remove":
scripts.append(f"ctx._source.remove('{v}');")
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
@ -307,21 +315,21 @@ class ESConnection(DocStoreConnection):
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="; ".join(scripts))
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
_ = ubq.execute()
return True
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="; ".join(scripts))
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
_ = ubq.execute()
return True
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: