diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index d5d0e1664..24e46ad83 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -482,6 +482,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
             if cks:
                 kbinfos["chunks"] = cks
+        kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
         if prompt_config.get("tavily_api_key"):
             tav = Tavily(prompt_config["tavily_api_key"])
             tav_res = tav.retrieve_chunks(" ".join(questions))
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 084f7b48f..4f64c1f8f 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -17,7 +17,7 @@ import json
 import logging
 import re
 import math
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass
 
 from rag.prompts.generator import relevant_chunks_with_toc
@@ -640,3 +640,52 @@ class Dealer:
             chunks.append(d)
 
         return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
+
+    def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]):
+        """Replace retrieved child chunks with their parent ("mom") chunks, averaging the children's similarities."""
+        if not chunks:
+            return []
+        idx_nms = [index_name(tid) for tid in tenant_ids]
+        mom_chunks = defaultdict(list)
+        i = 0
+        while i < len(chunks):
+            ck = chunks[i]
+            if not ck.get("mom_id"):
+                i += 1
+                continue
+            mom_chunks[ck["mom_id"]].append(chunks.pop(i))
+
+        if not mom_chunks:
+            return chunks
+
+        vector_size = 1024
+        for mom_id, cks in mom_chunks.items():
+            chunk = self.dataStore.get(mom_id, idx_nms, [ck["kb_id"] for ck in cks])
+            if not chunk:
+                # Parent chunk is missing from the store; keep the children as-is.
+                chunks.extend(cks)
+                continue
+            d = {
+                "chunk_id": mom_id,
+                "content_ltks": " ".join([ck["content_ltks"] for ck in cks]),
+                "content_with_weight": chunk["content_with_weight"],
+                "doc_id": chunk["doc_id"],
+                "docnm_kwd": chunk.get("docnm_kwd", ""),
+                "kb_id": chunk["kb_id"],
+                "important_kwd": [kwd for ck in cks for kwd in ck.get("important_kwd", [])],
+                "image_id": chunk.get("img_id", ""),
+                "similarity": np.mean([ck["similarity"] for ck in cks]),
+                "vector_similarity": np.mean([ck["similarity"] for ck in cks]),
+                "term_similarity": np.mean([ck["similarity"] for ck in cks]),
+                "vector": [0.0] * vector_size,
+                "positions": chunk.get("position_int", []),
+                "doc_type_kwd": chunk.get("doc_type_kwd", "")
+            }
+            for k in cks[0].keys():
+                if k[-4:] == "_vec":
+                    d["vector"] = cks[0][k]
+                    vector_size = len(cks[0][k])
+                    break
+            chunks.append(d)
+
+        return sorted(chunks, key=lambda x:x["similarity"]*-1)
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index d7cbced0c..714b886eb 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -734,7 +734,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
             mom_ck["available_int"] = 0
             flds = list(mom_ck.keys())
             for fld in flds:
-                if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int"]:
+                if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int", "position_int"]:
                     del mom_ck[fld]
             mothers.append(mom_ck)
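
A minimal, self-contained sketch of the child-to-parent merge that retrieval_by_children performs: children that share a mom_id are collapsed into one parent chunk whose similarity is the mean of the children's. merge_children and the plain-dict store are hypothetical stand-ins for the method and self.dataStore, and only a few of the merged fields are shown.

    from collections import defaultdict
    from statistics import mean

    def merge_children(chunks: list[dict], store: dict) -> list[dict]:
        """Group retrieved child chunks by mom_id, then replace each group
        with its parent chunk, averaging the children's similarities."""
        mom_chunks = defaultdict(list)
        merged = []
        for ck in chunks:
            if ck.get("mom_id"):
                mom_chunks[ck["mom_id"]].append(ck)
            else:
                merged.append(ck)  # no parent; kept as-is
        for mom_id, cks in mom_chunks.items():
            parent = store[mom_id]
            merged.append({
                "chunk_id": mom_id,
                "content_with_weight": parent["content_with_weight"],
                "similarity": mean(ck["similarity"] for ck in cks),
            })
        return sorted(merged, key=lambda x: -x["similarity"])

    store = {"m1": {"content_with_weight": "Full parent passage."}}
    children = [
        {"chunk_id": "c1", "mom_id": "m1", "similarity": 0.9},
        {"chunk_id": "c2", "mom_id": "m1", "similarity": 0.8},
        {"chunk_id": "c3", "similarity": 0.6},
    ]
    print(merge_children(children, store))
    # parent "m1" (similarity 0.85) comes first, then the standalone "c3"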