mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: TOC retrieval (#10456)
### What problem does this PR solve? #10436 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,12 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.prompts.generator import relevant_chunks_with_toc
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import rmSpace, get_float
|
||||
from rag.nlp import rag_tokenizer, query
|
||||
@ -514,3 +516,63 @@ class Dealer:
|
||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
key=lambda x: x[1] * -1)[:topn_tags]
|
||||
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
|
||||
|
||||
def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
    """Re-rank retrieved chunks using the table of contents of the best-matching document.

    The document whose chunks have the highest summed ``similarity`` is
    selected, its stored TOC records (``toc_kwd == "toc"``) are loaded from the
    data store, and ``relevant_chunks_with_toc`` is asked which chunk ids are
    relevant to ``query``. Chunks already present get the returned score added
    to their ``similarity``; chunks not yet present are fetched from the data
    store and appended.

    Args:
        query: The user question passed through to ``relevant_chunks_with_toc``.
        chunks: Retrieved chunks; each must provide ``chunk_id``, ``doc_id``,
            ``kb_id`` and ``similarity``. NOTE: this list is modified in place
            (scores are boosted and new chunks are appended).
        tenant_ids: Tenant ids, mapped to index names with ``index_name``.
        chat_mdl: Chat model handle forwarded to ``relevant_chunks_with_toc``.
        topn: Maximum number of chunks to return.

    Returns:
        ``[]`` when ``chunks`` is empty; ``chunks`` unchanged (not truncated)
        when no TOC could be loaded or no relevant ids were returned; otherwise
        at most ``topn`` chunks ordered by descending ``similarity``.
    """
    if not chunks:
        return []

    idx_nms = [index_name(tid) for tid in tenant_ids]

    # Rank documents by the summed similarity of their chunks and remember
    # which knowledge base each document was seen in.
    ranks, doc_id2kb_id = {}, {}
    for ck in chunks:
        ranks[ck["doc_id"]] = ranks.get(ck["doc_id"], 0) + ck["similarity"]
        doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
    # max() returns the first maximal key, i.e. the same document the previous
    # stable sort-and-take-first picked on ties.
    doc_id = max(ranks, key=ranks.get)
    kb_ids = [doc_id2kb_id[doc_id]]

    # Load the TOC records of the selected document.
    # NOTE(review): 0 and 128 look like offset/limit of the search - confirm
    # against the dataStore.search signature.
    es_res = self.dataStore.search(
        ["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
        OrderByExpr(), 0, 128, idx_nms, kb_ids,
    )
    toc = []
    dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
    for doc in dict_chunks.values():
        try:
            toc.extend(json.loads(doc["content_with_weight"]))
        except Exception as e:
            # Best effort: a malformed TOC record is logged and ignored.
            logging.exception(e)
    if not toc:
        return chunks

    ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
    if not ids:
        return chunks

    # Length of the zero placeholder used when a fetched chunk carries no
    # "*_vec" field; updated to the length of the last real vector seen.
    vector_size = 1024
    id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
    for cid, sim in ids:
        if cid in id2idx:
            chunks[id2idx[cid]]["similarity"] += sim
            continue

        chunk = self.dataStore.get(cid, idx_nms, kb_ids)
        if not chunk:
            # The TOC may reference a chunk the store can no longer return;
            # skip it instead of failing the whole retrieval.
            logging.warning("retrieval_by_toc: chunk %s not found in data store, skipped", cid)
            continue

        d = {
            "chunk_id": cid,
            "content_ltks": chunk["content_ltks"],
            "content_with_weight": chunk["content_with_weight"],
            "doc_id": doc_id,
            "docnm_kwd": chunk.get("docnm_kwd", ""),
            "kb_id": chunk["kb_id"],
            "important_kwd": chunk.get("important_kwd", []),
            "image_id": chunk.get("img_id", ""),
            "similarity": sim,
            "vector_similarity": sim,
            "term_similarity": sim,
            "vector": [0.0] * vector_size,
            "positions": chunk.get("position_int", []),
            "doc_type_kwd": chunk.get("doc_type_kwd", ""),
        }
        # Use the stored embedding (first field ending in "_vec") when present.
        for k in chunk:
            if k.endswith("_vec"):
                d["vector"] = chunk[k]
                vector_size = len(chunk[k])
                break

        # Register the new chunk so a repeated id accumulates its score on the
        # existing entry instead of being fetched and appended a second time.
        id2idx[cid] = len(chunks)
        chunks.append(d)

    # reverse=True keeps the relative order of equal scores, matching the
    # previous negated-key sort.
    return sorted(chunks, key=lambda x: x["similarity"], reverse=True)[:topn]
|
||||
|
||||
Reference in New Issue
Block a user