Feat: support TOC transformer. (#11685)
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
             return
         if cks:
             kbinfos["chunks"] = cks
+        kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
         if self._param.use_kg:
             ck = settings.kg_retriever.retrieval(query,
                                                  [kb.tenant_id for kb in kbs],
```
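The hunk above only shows the new call site; what `settings.retriever.retrieval_by_children` actually does is not part of this excerpt. As a rough, purely illustrative sketch (the function body, the `fetch_children` lookup, and all names below are assumptions, not ragflow's implementation), a child-aware expansion step over retrieved chunks might look like this:

```python
from typing import Callable

# Hypothetical sketch only: expand retrieved parent chunks with their child
# chunks, deduplicating by id. `fetch_children` stands in for whatever store
# lookup the real retriever performs; nothing here is taken from ragflow.
def retrieval_by_children(chunks: list[dict],
                          fetch_children: Callable[[str], list[dict]]) -> list[dict]:
    expanded: list[dict] = []
    seen: set[str] = set()
    for ck in chunks:
        children = fetch_children(ck["id"]) or [ck]  # fall back to the chunk itself
        for child in children:
            if child["id"] not in seen:
                seen.add(child["id"])
                expanded.append(child)
    return expanded
```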
```diff
@@ -12,10 +12,17 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+import json
+import logging
 import random
-from copy import deepcopy
+from copy import deepcopy, copy

 import trio
+import xxhash

 from agent.component.llm import LLMParam, LLM
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.prompts.generator import run_toc_from_text


 class ExtractorParam(ProcessParamBase, LLMParam):
```
```diff
@@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
 class Extractor(ProcessBase, LLM):
     component_name = "Extractor"

+    def _build_TOC(self, docs):
+        self.callback(message="Start to generate table of contents ...")
+        docs = sorted(docs, key=lambda d: (
+            d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
+            d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
+        ))
+        toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
+        logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
+        ii = 0
+        while ii < len(toc):
+            try:
+                idx = int(toc[ii]["chunk_id"])
+                del toc[ii]["chunk_id"]
+                toc[ii]["ids"] = [docs[idx]["id"]]
+                if ii == len(toc) - 1:
+                    break
+                for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
+                    toc[ii]["ids"].append(docs[jj]["id"])
+            except Exception as e:
+                logging.exception(e)
+            ii += 1
+
+        if toc:
+            d = deepcopy(docs[-1])
+            d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
+            d["toc_kwd"] = "toc"
+            d["available_int"] = 0
+            d["page_num_int"] = [100000000]
+            d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+            return d
+        return None
+
     async def _invoke(self, **kwargs):
         self.set_output("output_format", "chunks")
         self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
```
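To make the `chunk_id` → `ids` bookkeeping in `_build_TOC` concrete, here is a small self-contained rehearsal of that loop on made-up data. Only the loop mirrors the diff; the docs, headings, and the `title` key are fabricated for illustration (the real TOC entries come from `run_toc_from_text`).

```python
# Each TOC entry initially points at the chunk where its heading starts
# ("chunk_id"); the loop replaces that with the list of chunk ids covered
# by the entry, running up to (and including) the next heading's start chunk.
docs = [{"id": f"doc-{i}", "text": f"chunk {i}"} for i in range(6)]
toc = [
    {"title": "1 Introduction", "chunk_id": "0"},
    {"title": "2 Methods", "chunk_id": "2"},
    {"title": "3 Results", "chunk_id": "4"},
]

ii = 0
while ii < len(toc):
    idx = int(toc[ii]["chunk_id"])
    del toc[ii]["chunk_id"]
    toc[ii]["ids"] = [docs[idx]["id"]]
    if ii == len(toc) - 1:
        break
    for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
        toc[ii]["ids"].append(docs[jj]["id"])
    ii += 1

print(toc)
# [{'title': '1 Introduction', 'ids': ['doc-0', 'doc-1', 'doc-2']},
#  {'title': '2 Methods',      'ids': ['doc-2', 'doc-3', 'doc-4']},
#  {'title': '3 Results',      'ids': ['doc-4']}]
```

Note that adjacent entries share their boundary chunk, because the inner range runs through the next heading's start chunk inclusively, and the final entry keeps only its own start chunk.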
```diff
@@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
                 chunks_key = k

+        if chunks:
+            if self._param.field_name == "toc":
+                toc = self._build_TOC(chunks)
+                chunks.append(toc)
+                self.set_output("chunks", chunks)
+                return

         prog = 0
         for i, ck in enumerate(chunks):
             args[chunks_key] = ck["text"]
```
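In the `field_name == "toc"` branch above, the TOC ends up as one extra chunk appended to the output list, tagged by `_build_TOC` with `toc_kwd == "toc"` and `available_int == 0`. A minimal sketch of how a downstream consumer could separate it from the content chunks; the sample data and the `title` key are made up, while the marker fields come from the diff:

```python
import json

# Fabricated sample output: two content chunks plus the appended TOC chunk.
chunks = [
    {"id": "c1", "text": "1 Introduction ...", "available_int": 1},
    {"id": "c2", "text": "2 Methods ...", "available_int": 1},
    {
        "id": "d41d8cd98f00b204",
        "toc_kwd": "toc",
        "available_int": 0,
        "content_with_weight": json.dumps([{"title": "1 Introduction", "ids": ["c1"]}]),
    },
]

# Split on the toc_kwd marker set by _build_TOC.
toc_chunks = [c for c in chunks if c.get("toc_kwd") == "toc"]
content_chunks = [c for c in chunks if c.get("toc_kwd") != "toc"]

toc = json.loads(toc_chunks[0]["content_with_weight"]) if toc_chunks else []
print(toc)                  # [{'title': '1 Introduction', 'ids': ['c1']}]
print(len(content_chunks))  # 2
```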
```diff
@@ -944,7 +944,7 @@ async def do_handle_task(task):
     logging.info(progress_message)
     progress_callback(msg=progress_message)
     if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
-        toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
+        toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)

     chunk_count = len(set([chunk["id"] for chunk in chunks]))
     start_ts = timer()
```
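The last hunk is only a whitespace fix in the `executor.submit(...)` call. For context, this is the standard `concurrent.futures` pattern of kicking off TOC generation in a worker thread and joining it later; the sketch below keeps the call shape from the diff but stubs out `build_TOC`, the task dict, and the progress callback, all of which are placeholders here.

```python
from concurrent.futures import ThreadPoolExecutor

def build_TOC(task, chunks, progress_callback):
    # Stub standing in for the real TOC builder invoked by the task executor.
    progress_callback(msg="building TOC ...")
    return {"toc_kwd": "toc", "entries": len(chunks)}

def progress_callback(prog=None, msg=""):
    print(prog, msg)

task = {"parser_id": "naive", "parser_config": {"toc_extraction": True}}
chunks = [{"id": "c1"}, {"id": "c2"}]

with ThreadPoolExecutor(max_workers=2) as executor:
    toc_thread = None
    if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
        # Run TOC generation in the background while other post-processing continues.
        toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)

    # ... other work on `chunks` would happen here ...

    if toc_thread is not None:
        toc_chunk = toc_thread.result()  # join the background job
        print(toc_chunk)
```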