Feat: support TOC transformer. (#11685)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-12-03 12:27:50 +08:00
committed by GitHub
parent 6fc7def562
commit b5ad7b7062
3 changed files with 48 additions and 2 deletions

View File

@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
return return
if cks: if cks:
kbinfos["chunks"] = cks kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg: if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query, ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs], [kb.tenant_id for kb in kbs],

View File

@ -12,10 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging
import random import random
from copy import deepcopy from copy import deepcopy, copy
import trio
import xxhash
from agent.component.llm import LLMParam, LLM from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam): class ExtractorParam(ProcessParamBase, LLMParam):
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM): class Extractor(ProcessBase, LLM):
component_name = "Extractor" component_name = "Extractor"
def _build_TOC(self, docs):
self.callback(message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = copy.deepcopy(docs[-1])
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks") self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.") self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
chunks_key = k chunks_key = k
if chunks: if chunks:
if self._param.field_name == "toc":
toc = self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0 prog = 0
for i, ck in enumerate(chunks): for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"] args[chunks_key] = ck["text"]

View File

@ -944,7 +944,7 @@ async def do_handle_task(task):
logging.info(progress_message) logging.info(progress_message)
progress_callback(msg=progress_message) progress_callback(msg=progress_message)
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False): if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback) toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
chunk_count = len(set([chunk["id"] for chunk in chunks])) chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer() start_ts = timer()