#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import random
from copy import deepcopy

import xxhash

from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text


class ExtractorParam(ProcessParamBase, LLMParam):
    def __init__(self):
        super().__init__()
        self.field_name = ""

    def check(self):
        super().check()
        self.check_empty(self.field_name, "Result Destination")


class Extractor(ProcessBase, LLM):
    component_name = "Extractor"

    async def _build_TOC(self, docs):
        self.callback(0.2, message="Start to generate table of contents ...")
        # Order chunks by page number, then by vertical position on the page.
        # Both fields may be stored as a list (multi-page chunks) or a scalar.
        docs = sorted(docs, key=lambda d: (
            d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
            d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
        ))
        toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
        logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))

        # Replace each entry's chunk index with the ids of every chunk it covers,
        # i.e. all chunks up to the start of the next TOC entry.
        ii = 0
        while ii < len(toc):
            try:
                idx = int(toc[ii]["chunk_id"])
                del toc[ii]["chunk_id"]
                toc[ii]["ids"] = [docs[idx]["id"]]
                if ii == len(toc) - 1:
                    break
                for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
                    toc[ii]["ids"].append(docs[jj]["id"])
            except Exception as e:
                logging.exception(e)
            ii += 1

        if toc:
            # Store the whole TOC as an extra chunk: excluded from normal retrieval
            # (available_int=0) and sorted after all content pages.
            d = deepcopy(docs[-1])
            d["doc_id"] = self._canvas._doc_id
            d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
            d["toc_kwd"] = "toc"
            d["available_int"] = 0
            d["page_num_int"] = [100000000]
            d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
            return d
        return None

    async def _invoke(self, **kwargs):
        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                chunks = deepcopy(args[k])
                chunks_key = k

        if chunks:
            if self._param.field_name == "toc":
                for ck in chunks:
                    ck["doc_id"] = self._canvas._doc_id
                    ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
                toc = await self._build_TOC(chunks)
                if toc:  # _build_TOC returns None when no TOC could be generated
                    chunks.append(toc)
                self.set_output("chunks", chunks)
                return

            prog = 0
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                prog += 1. / len(chunks)
                if i % (len(chunks) // 100 + 1) == 1:
                    self.callback(prog, f"{i + 1} / {len(chunks)}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])