mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: debugging toc part. (#10486)
### What problem does this PR solve? #10436 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -25,7 +25,7 @@ class SplitterFromUpstream(BaseModel):
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
|
||||
@ -126,7 +126,7 @@ class Tokenizer(ProcessBase):
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
elif ck.get("text"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
@ -155,6 +155,8 @@ class Tokenizer(ProcessBase):
|
||||
for i, ck in enumerate(chunks):
|
||||
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
|
||||
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
|
||||
if not ck.get("text"):
|
||||
continue
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
|
||||
@ -613,13 +613,13 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
|
||||
dels = get_delimiters(delimiter)
|
||||
for sec, pos in sections:
|
||||
if num_tokens_from_string(sec) < chunk_token_num:
|
||||
add_chunk(sec, pos)
|
||||
add_chunk("\n"+sec, pos)
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk(sub_sec, pos)
|
||||
add_chunk("\n"+sub_sec, pos)
|
||||
|
||||
return cks
|
||||
|
||||
@ -669,13 +669,13 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk(sub_sec, image, text_pos)
|
||||
add_chunk("\n"+sub_sec, image, text_pos)
|
||||
else:
|
||||
split_sec = re.split(r"(%s)" % dels, text)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk(sub_sec, image)
|
||||
add_chunk("\n"+sub_sec, image)
|
||||
|
||||
return cks, result_images
|
||||
|
||||
@ -757,7 +757,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk(sub_sec, image,"")
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
line = ""
|
||||
|
||||
if line:
|
||||
@ -765,7 +765,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk(sub_sec, image,"")
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
|
||||
return cks, images
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ import jinja2
|
||||
import json_repair
|
||||
import trio
|
||||
from api.utils import hash_str2int
|
||||
from rag.nlp import is_chinese
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import encoder, num_tokens_from_string
|
||||
@ -672,7 +672,7 @@ def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
|
||||
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
|
||||
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
|
||||
# Generate TOC from text chunks with text llms
|
||||
async def gen_toc_from_text(txt_info: dict, chat_mdl):
|
||||
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
|
||||
try:
|
||||
ans = gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
|
||||
@ -682,6 +682,8 @@ async def gen_toc_from_text(txt_info: dict, chat_mdl):
|
||||
)
|
||||
print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
|
||||
txt_info["toc"] = ans if ans else []
|
||||
if callback:
|
||||
callback(msg="")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -707,14 +709,14 @@ def split_chunks(chunks, max_length: int):
|
||||
return result
|
||||
|
||||
|
||||
async def run_toc_from_text(chunks, chat_mdl):
|
||||
async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
|
||||
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
|
||||
)
|
||||
|
||||
input_budget = 1024 if input_budget > 1024 else input_budget
|
||||
chunk_sections = split_chunks(chunks, input_budget)
|
||||
res = []
|
||||
titles = []
|
||||
|
||||
chunks_res = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
@ -722,21 +724,21 @@ async def run_toc_from_text(chunks, chat_mdl):
|
||||
if not chunk:
|
||||
continue
|
||||
chunks_res.append({"chunks": chunk})
|
||||
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl)
|
||||
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
|
||||
|
||||
for chunk in chunks_res:
|
||||
res.extend(chunk.get("toc", []))
|
||||
titles.extend(chunk.get("toc", []))
|
||||
|
||||
print(res, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
|
||||
print(titles, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
|
||||
|
||||
# Filter out entries with title == -1
|
||||
prune = len(titles) > 512
|
||||
max_len = 12 if prune else 22
|
||||
filtered = []
|
||||
for x in res:
|
||||
for x in titles:
|
||||
if not x.get("title") or x["title"] == "-1":
|
||||
continue
|
||||
if is_chinese(x["title"]) and len(x["title"]) > 12:
|
||||
continue
|
||||
if len(x["title"].split(" ")) > 12:
|
||||
if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
|
||||
continue
|
||||
if re.match(r"[0-9,.()/ -]+$", x["title"]):
|
||||
continue
|
||||
@ -751,8 +753,12 @@ async def run_toc_from_text(chunks, chat_mdl):
|
||||
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
|
||||
|
||||
# Merge structure and content (by index)
|
||||
prune = len(toc_with_levels) > 512
|
||||
max_lvl = sorted([t.get("level", "0") for t in toc_with_levels])[-1]
|
||||
merged = []
|
||||
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
|
||||
if prune and toc_item.get("level", "0") >= max_lvl:
|
||||
continue
|
||||
merged.append({
|
||||
"level": toc_item.get("level", "0"),
|
||||
"title": toc_item.get("title", ""),
|
||||
@ -776,7 +782,7 @@ def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
|
||||
print(ans, "::::::::::::::::::::::::::::::::::::", flush=True)
|
||||
id2score = {}
|
||||
for ti, sc in zip(toc, ans):
|
||||
if sc.get("score", -1) < 1:
|
||||
if not isinstance(sc, dict) or sc.get("score", -1) < 1:
|
||||
continue
|
||||
for id in ti.get("ids", []):
|
||||
if id not in id2score:
|
||||
|
||||
@ -370,14 +370,14 @@ async def build_chunks(task, progress_callback):
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
|
||||
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["parser_config"].get("toc_extraction", True):
|
||||
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
||||
progress_callback(msg="Start to generate table of content ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
docs = sorted(docs, key=lambda d:(
|
||||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||
))
|
||||
toc: list[dict] = await run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl)
|
||||
toc: list[dict] = await run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
|
||||
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||
ii = 0
|
||||
while ii < len(toc):
|
||||
@ -387,7 +387,7 @@ async def build_chunks(task, progress_callback):
|
||||
toc[ii]["ids"] = [docs[idx]["id"]]
|
||||
if ii == len(toc) -1:
|
||||
break
|
||||
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])):
|
||||
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
|
||||
toc[ii]["ids"].append(docs[jj]["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
Reference in New Issue
Block a user