diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index fd9096cb6..52e295fd0 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
             return
         if cks:
             kbinfos["chunks"] = cks
+            kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
         if self._param.use_kg:
             ck = settings.kg_retriever.retrieval(query,
                                                 [kb.tenant_id for kb in kbs],
diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py
index 70464148a..45698b204 100644
--- a/rag/flow/extractor/extractor.py
+++ b/rag/flow/extractor/extractor.py
@@ -12,10 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import logging
 import random
-from copy import deepcopy
+from copy import deepcopy, copy
+
+import trio
+import xxhash
+
 from agent.component.llm import LLMParam, LLM
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.prompts.generator import run_toc_from_text
 
 
 class ExtractorParam(ProcessParamBase, LLMParam):
@@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
 class Extractor(ProcessBase, LLM):
     component_name = "Extractor"
 
+    def _build_TOC(self, docs):
+        self.callback(message="Start to generate table of content ...")
+        docs = sorted(docs, key=lambda d:(
+            d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
+            d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
+        ))
+        toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
+        logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
+        ii = 0
+        while ii < len(toc):
+            try:
+                idx = int(toc[ii]["chunk_id"])
+                del toc[ii]["chunk_id"]
+                toc[ii]["ids"] = [docs[idx]["id"]]
+                if ii == len(toc) -1:
+                    break
+                for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
+                    toc[ii]["ids"].append(docs[jj]["id"])
+            except Exception as e:
+                logging.exception(e)
+            ii += 1
+
+        if toc:
+            d = deepcopy(docs[-1])
+            d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
+            d["toc_kwd"] = "toc"
+            d["available_int"] = 0
+            d["page_num_int"] = [100000000]
+            d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+            return d
+        return None
+
     async def _invoke(self, **kwargs):
         self.set_output("output_format", "chunks")
         self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
             chunks_key = k
 
         if chunks:
+            if self._param.field_name == "toc":
+                toc = self._build_TOC(chunks)
+                if toc:
+                    chunks.append(toc)
+                self.set_output("chunks", chunks)
+                return
             prog = 0
             for i, ck in enumerate(chunks):
                 args[chunks_key] = ck["text"]
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index a91421007..25dea2b24 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -944,7 +944,7 @@ async def do_handle_task(task):
             logging.info(progress_message)
             progress_callback(msg=progress_message)
         if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
-            toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
+            toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
     chunk_count = len(set([chunk["id"] for chunk in chunks]))
     start_ts = timer()
 