Feat: debugging toc part. (#10486)

### What problem does this PR solve?

#10436

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-10-11 18:45:21 +08:00
committed by GitHub
parent a0d5f81098
commit 7d2f65671f
6 changed files with 32 additions and 24 deletions

View File

@ -23,7 +23,7 @@ import jinja2
import json_repair
import trio
from api.utils import hash_str2int
from rag.nlp import is_chinese
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
@ -672,7 +672,7 @@ def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl):
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
@ -682,6 +682,8 @@ async def gen_toc_from_text(txt_info: dict, chat_mdl):
)
print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
txt_info["toc"] = ans if ans else []
if callback:
callback(msg="")
except Exception as e:
logging.exception(e)
@ -707,14 +709,14 @@ def split_chunks(chunks, max_length: int):
return result
async def run_toc_from_text(chunks, chat_mdl):
async def run_toc_from_text(chunks, chat_mdl, callback=None):
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
res = []
titles = []
chunks_res = []
async with trio.open_nursery() as nursery:
@ -722,21 +724,21 @@ async def run_toc_from_text(chunks, chat_mdl):
if not chunk:
continue
chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl)
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
for chunk in chunks_res:
res.extend(chunk.get("toc", []))
titles.extend(chunk.get("toc", []))
print(res, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
print(titles, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
# Filter out entries with title == -1
prune = len(titles) > 512
max_len = 12 if prune else 22
filtered = []
for x in res:
for x in titles:
if not x.get("title") or x["title"] == "-1":
continue
if is_chinese(x["title"]) and len(x["title"]) > 12:
continue
if len(x["title"].split(" ")) > 12:
if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
continue
if re.match(r"[0-9,.()/ -]+$", x["title"]):
continue
@ -751,8 +753,12 @@ async def run_toc_from_text(chunks, chat_mdl):
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
# Merge structure and content (by index)
prune = len(toc_with_levels) > 512
max_lvl = sorted([t.get("level", "0") for t in toc_with_levels])[-1]
merged = []
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
"level": toc_item.get("level", "0"),
"title": toc_item.get("title", ""),
@ -776,7 +782,7 @@ def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
id2score = {}
for ti, sc in zip(toc, ans):
if sc.get("score", -1) < 1:
if not isinstance(sc, dict) or sc.get("score", -1) < 1:
continue
for id in ti.get("ids", []):
if id not in id2score: