Feat: support TOC transformer. (#11685)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-12-03 12:27:50 +08:00
committed by GitHub
parent 6fc7def562
commit b5ad7b7062
3 changed files with 48 additions and 2 deletions

View File

@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],

View File

@ -12,10 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
from copy import deepcopy
from copy import deepcopy, copy
import trio
import xxhash
from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam):
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM):
component_name = "Extractor"
def _build_TOC(self, docs):
self.callback(message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = copy.deepcopy(docs[-1])
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
chunks_key = k
if chunks:
if self._param.field_name == "toc":
toc = self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]