From 0d8791936ee19b884de6a727c17ab28da7040ce1 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 10 Oct 2025 17:07:55 +0800 Subject: [PATCH] Feat: TOC retrieval (#10456) ### What problem does this PR solve? #10436 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/tools/retrieval.py | 6 ++ api/db/services/dialog_service.py | 4 + api/utils/api_utils.py | 6 +- deepdoc/parser/pdf_parser.py | 2 +- graphrag/utils.py | 11 +-- rag/flow/parser/parser.py | 6 +- rag/llm/chat_model.py | 3 +- rag/nlp/search.py | 62 +++++++++++++++ rag/prompts/assign_toc_levels.md | 32 ++++---- rag/prompts/generator.py | 113 +++++++++++++++++++++------- rag/prompts/toc_from_text_system.md | 62 ++++++++------- rag/svr/task_executor.py | 34 ++++++++- 12 files changed, 251 insertions(+), 90 deletions(-) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 07c16d97d..e2e24ea35 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -57,6 +57,7 @@ class RetrievalParam(ToolParamBase): self.empty_response = "" self.use_kg = False self.cross_languages = [] + self.toc_enhance = False def check(self): self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") @@ -134,6 +135,11 @@ class Retrieval(ToolBase, ABC): rerank_mdl=rerank_mdl, rank_feature=label_question(query, kbs), ) + if self._param.toc_enhance: + chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) + cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) + if cks: + kbinfos["chunks"] = cks if self._param.use_kg: ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index a8ddf178d..9bde6238d 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -466,6 +466,10 @@ def chat(dialog, messages, stream=True, **kwargs): rerank_mdl=rerank_mdl, rank_feature=label_question(" ".join(questions), kbs), ) + if prompt_config.get("toc_enhance"): + cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) + if cks: + kbinfos["chunks"] = cks if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8730dff4d..1579e852b 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -51,9 +51,6 @@ from api import settings from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC from api.db import ActiveEnum from api.db.db_models import APIToken -from api.db.services import UserService -from api.db.services.llm_service import LLMService -from api.db.services.tenant_llm_service import TenantLLMService from api.utils.json import CustomJSONEncoder, json_dumps from api.utils import get_uuid from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @@ -239,6 +236,7 @@ def not_allowed_parameters(*params): def active_required(f): @wraps(f) def wrapper(*args, **kwargs): + from api.db.services import UserService user_id = current_user.id usr = UserService.filter_by_id(user_id) # check is_active @@ -544,6 +542,8 @@ def check_duplicate_ids(ids, id_type="item"): def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: + from api.db.services.llm_service import LLMService + from 
api.db.services.tenant_llm_service import TenantLLMService """ Verifies availability of an embedding model for a specific tenant. diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index c73b610ad..2cf14b88a 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -1048,7 +1048,7 @@ class RAGFlowPdfParser: def parse_into_bboxes(self, fnm, callback=None, zoomin=3): start = timer() - self.__images__(fnm, zoomin) + self.__images__(fnm, zoomin, callback=callback) if callback: callback(0.40, "OCR finished ({:.2f}s)".format(timer() - start)) diff --git a/graphrag/utils.py b/graphrag/utils.py index 5e64cdb11..877380d6a 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -92,10 +92,7 @@ def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]] def get_llm_cache(llmnm, txt, history, genconf): hasher = xxhash.xxh64() - hasher.update(str(llmnm).encode("utf-8")) - hasher.update(str(txt).encode("utf-8")) - hasher.update(str(history).encode("utf-8")) - hasher.update(str(genconf).encode("utf-8")) + hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) k = hasher.hexdigest() bin = REDIS_CONN.get(k) @@ -106,11 +103,7 @@ def get_llm_cache(llmnm, txt, history, genconf): def set_llm_cache(llmnm, txt, v, history, genconf): hasher = xxhash.xxh64() - hasher.update(str(llmnm).encode("utf-8")) - hasher.update(str(txt).encode("utf-8")) - hasher.update(str(history).encode("utf-8")) - hasher.update(str(genconf).encode("utf-8")) - + hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) k = hasher.hexdigest() REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600) diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 424b2d6fa..e08629f3d 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -366,6 +366,7 @@ class Parser(ProcessBase): email_content = {} conf = self._param.setups["email"] + self.set_output("output_format", conf["output_format"]) target_fields = conf["fields"] _, ext = os.path.splitext(name) @@ -442,8 +443,9 @@ class Parser(ProcessBase): } # get body if "body" in target_fields: - email_content["text"] = msg.body # usually empty. try text_html instead - email_content["text_html"] = msg.htmlBody + email_content["text"] = msg.body[0] if isinstance(msg.body, list) and msg.body else msg.body + if not email_content["text"] and msg.htmlBody: + email_content["text"] = msg.htmlBody[0] if isinstance(msg.htmlBody, list) and msg.htmlBody else msg.htmlBody # get attachments if "attachments" in target_fields: attachments = [] diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index dd088c83b..2a77a45eb 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -132,8 +132,7 @@ class Base(ABC): "tool_choice", "logprobs", "top_logprobs", - "extra_headers", - "enable_thinking" + "extra_headers" } gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} diff --git a/rag/nlp/search.py b/rag/nlp/search.py index db1423095..89c9d5bfc 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 #
+import json
 import logging
 import re
 import math
 from collections import OrderedDict
 from dataclasses import dataclass
 
+from rag.prompts.generator import relevant_chunks_with_toc
 from rag.settings import TAG_FLD, PAGERANK_FLD
 from rag.utils import rmSpace, get_float
 from rag.nlp import rag_tokenizer, query
@@ -514,3 +516,65 @@ class Dealer:
         tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
 
         return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
+
+    def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
+        if not chunks:
+            return []
+        idx_nms = [index_name(tid) for tid in tenant_ids]
+        ranks, doc_id2kb_id = {}, {}
+        for ck in chunks:
+            if ck["doc_id"] not in ranks:
+                ranks[ck["doc_id"]] = 0
+            ranks[ck["doc_id"]] += ck["similarity"]
+            doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
+        doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
+        kb_ids = [doc_id2kb_id[doc_id]]
+        es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
+                                       kb_ids)
+        toc = []
+        dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
+        for _, doc in dict_chunks.items():
+            try:
+                toc.extend(json.loads(doc["content_with_weight"]))
+            except Exception as e:
+                logging.exception(e)
+        if not toc:
+            return chunks
+
+        ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
+        if not ids:
+            return chunks
+
+        vector_size = 1024
+        id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
+        for cid, sim in ids:
+            if cid in id2idx:
+                chunks[id2idx[cid]]["similarity"] += sim
+                continue
+            chunk = self.dataStore.get(cid, idx_nms, kb_ids)
+            if not chunk:  # the stored TOC may reference chunks that no longer exist
+                continue
+            d = {
+                "chunk_id": cid,
+                "content_ltks": chunk["content_ltks"],
+                "content_with_weight": chunk["content_with_weight"],
+                "doc_id": doc_id,
+                "docnm_kwd": chunk.get("docnm_kwd", ""),
+                "kb_id": chunk["kb_id"],
+                "important_kwd": chunk.get("important_kwd", []),
+                "image_id": chunk.get("img_id", ""),
+                "similarity": sim,
+                "vector_similarity": sim,
+                "term_similarity": sim,
+                "vector": [0.0] * vector_size,
+                "positions": chunk.get("position_int", []),
+                "doc_type_kwd": chunk.get("doc_type_kwd", "")
+            }
+            for k in chunk.keys():
+                if k[-4:] == "_vec":
+                    d["vector"] = chunk[k]
+                    vector_size = len(chunk[k])
+                    break
+            chunks.append(d)
+
+        return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
diff --git a/rag/prompts/assign_toc_levels.md b/rag/prompts/assign_toc_levels.md
index fff0cd8b3..d35dee779 100644
--- a/rag/prompts/assign_toc_levels.md
+++ b/rag/prompts/assign_toc_levels.md
@@ -1,4 +1,4 @@
-You are given a JSON array of TOC items. Each item has at least {"title": string} and may include an existing structure.
+You are given a JSON array of TOC (table of contents) items. Each item has at least {"title": string} and may include an existing hierarchical level.
 
 Task
 - For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
@@ -9,7 +9,7 @@ Task
 
 Output
 - Return a valid JSON array only (no extra text).
-- Each element must be {"structure": "1|2|3", "title": <original title>}.
+- Each element must be {"level": "1|2|3", "title": <original title>}.
 - title must be the original title string.
 
 Examples
 
 Example A (chapters with sections)
 Input:
 [{"title":"Chapter 1 Methods"}, {"title":"Section 1 Definition"}, {"title":"Section 2 Process"}, {"title":"Chapter 2 Experiment"}]
 
 Output:
 [
-  {"structure":"1","title":"Chapter 1 Methods"},
-  {"structure":"2","title":"Section 1 Definition"},
-  {"structure":"2","title":"Section 2 Process"},
-  {"structure":"1","title":"Chapter 2 Experiment"}
+  {"level":"1","title":"Chapter 1 Methods"},
+  {"level":"2","title":"Section 1 Definition"},
+  {"level":"2","title":"Section 2 Process"},
+  {"level":"1","title":"Chapter 2 Experiment"}
 ]
 
 Example B (parts with chapters)
 Input:
 [{"title":"Part I Theory"}, {"title":"Chapter 1 Basics"}, {"title":"Chapter 2 Methods"}, {"title":"Part II Applications"}, {"title":"Chapter 3 Case Studies"}]
 
 Output:
 [
-  {"structure":"1","title":"Part I Theory"},
-  {"structure":"2","title":"Chapter 1 Basics"},
-  {"structure":"2","title":"Chapter 2 Methods"},
-  {"structure":"1","title":"Part II Applications"},
-  {"structure":"2","title":"Chapter 3 Case Studies"}
+  {"level":"1","title":"Part I Theory"},
+  {"level":"2","title":"Chapter 1 Basics"},
+  {"level":"2","title":"Chapter 2 Methods"},
+  {"level":"1","title":"Part II Applications"},
+  {"level":"2","title":"Chapter 3 Case Studies"}
 ]
 
 Example C (plain headings)
 Input:
 [{"title":"Introduction"}, {"title":"Background and Motivation"}, {"title":"Related Work"}, {"title":"Methodology"}, {"title":"Evaluation"}]
 
 Output:
 [
-  {"structure":"1","title":"Introduction"},
-  {"structure":"2","title":"Background and Motivation"},
-  {"structure":"2","title":"Related Work"},
-  {"structure":"1","title":"Methodology"},
-  {"structure":"1","title":"Evaluation"}
+  {"level":"1","title":"Introduction"},
+  {"level":"2","title":"Background and Motivation"},
+  {"level":"2","title":"Related Work"},
+  {"level":"1","title":"Methodology"},
+  {"level":"1","title":"Evaluation"}
 ]
\ No newline at end of file
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index fa812c1ff..73e1201f5 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -21,7 +21,9 @@ from copy import deepcopy
 from typing import Tuple
 
 import jinja2
 import json_repair
+import trio
 
 from api.utils import hash_str2int
+from rag.nlp import is_chinese
 from rag.prompts.template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -440,11 +442,17 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
 
 
 def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
+    from graphrag.utils import get_llm_cache, set_llm_cache
+    cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
+    if cached:
+        return json_repair.loads(cached)
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
     ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
     ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
     try:
-        return json_repair.loads(ans)
+        res = json_repair.loads(ans)
+        set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
+        return res
     except Exception:
         logging.exception(f"Loading json failure: {ans}")
@@ -651,29 +659,31 @@ def toc_transformer(toc_pages, chat_mdl):
 
 TOC_LEVELS = load_prompt("assign_toc_levels")
 def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
-    print("\nBegin TOC level assignment...\n")
-
-    ans = gen_json(
+    if not toc_secs:
+        return []
+    return gen_json(
         PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
         str(toc_secs),
         chat_mdl,
         gen_conf
     )
-
-    return ans
 
 
 TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
 TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
 # Generate TOC from text chunks with text llms
-def gen_toc_from_text(text, chat_mdl):
-    ans = gen_json(
-        PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
-        PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text=text),
-        chat_mdl,
-        gen_conf={"temperature": 0.0, "top_p": 0.9, "enable_thinking": False, }
gen_conf={"temperature": 0.0, "top_p": 0.9, "enable_thinking": False, } - ) - return ans +async def gen_toc_from_text(txt_info: dict, chat_mdl): + try: + ans = gen_json( + PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(), + PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])), + chat_mdl, + gen_conf={"temperature": 0.0, "top_p": 0.9} + ) + print(ans, "::::::::::::::::::::::::::::::::::::", flush=True) + txt_info["toc"] = ans if ans else [] + except Exception as e: + logging.exception(e) def split_chunks(chunks, max_length: int): @@ -690,44 +700,91 @@ def split_chunks(chunks, max_length: int): if batch_tokens + t > max_length: result.append(batch) batch, batch_tokens = [], 0 - batch.append({"id": idx, "text": chunk}) + batch.append({idx: chunk}) batch_tokens += t if batch: result.append(batch) return result -def run_toc_from_text(chunks, chat_mdl): +async def run_toc_from_text(chunks, chat_mdl): input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string( TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM ) - input_budget = 2000 if input_budget > 2000 else input_budget + input_budget = 1024 if input_budget > 1024 else input_budget chunk_sections = split_chunks(chunks, input_budget) res = [] - for chunk in chunk_sections: - ans = gen_toc_from_text(chunk, chat_mdl) - res.extend(ans) + chunks_res = [] + async with trio.open_nursery() as nursery: + for i, chunk in enumerate(chunk_sections): + if not chunk: + continue + chunks_res.append({"chunks": chunk}) + nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl) + + for chunk in chunks_res: + res.extend(chunk.get("toc", [])) + + print(res, ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") # Filter out entries with title == -1 - filtered = [x for x in res if x.get("title") and x.get("title") != "-1"] + filtered = [] + for x in res: + if not x.get("title") or x["title"] == "-1": + continue + if is_chinese(x["title"]) and len(x["title"]) > 12: + continue + if len(x["title"].split(" ")) > 12: + continue + if re.match(r"[0-9,.()/ -]+$", x["title"]): + continue + filtered.append(x) - print("\n\nFiltered TOC sections:\n", filtered) + logging.info(f"\n\nFiltered TOC sections:\n{filtered}") - # Generate initial structure (structure/title) - raw_structure = [{"structure": "0", "title": x.get("title", "")} for x in filtered] + # Generate initial level (level/title) + raw_structure = [x.get("title", "") for x in filtered] # Assign hierarchy levels using LLM - toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False}) + toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) # Merge structure and content (by index) merged = [] for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)): merged.append({ - "structure": toc_item.get("structure", "0"), + "level": toc_item.get("level", "0"), "title": toc_item.get("title", ""), - "content": src_item.get("content", ""), + "chunk_id": src_item.get("chunk_id", ""), }) - return merged \ No newline at end of file + return merged + + +TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system") +TOC_RELEVANCE_USER = load_prompt("toc_relevance_user") +def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): + import numpy as np + try: + ans = gen_json( + PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(), + 
+            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join([json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
+            chat_mdl,
+            gen_conf={"temperature": 0.0, "top_p": 0.9}
+        )
+        logging.debug(f"TOC relevance: {ans}")
+        id2score = {}
+        for ti, sc in zip(toc, ans):
+            if sc.get("score", -1) < 1:
+                continue
+            for cid in ti.get("ids", []):
+                if cid not in id2score:
+                    id2score[cid] = []
+                id2score[cid].append(sc["score"] / 5.)
+        for cid in id2score.keys():
+            id2score[cid] = np.mean(id2score[cid])
+        return [(cid, sc) for cid, sc in list(id2score.items()) if sc >= 0.3][:topn]
+    except Exception as e:
+        logging.exception(e)
+        return []
diff --git a/rag/prompts/toc_from_text_system.md b/rag/prompts/toc_from_text_system.md
index f982df47d..7090f3058 100644
--- a/rag/prompts/toc_from_text_system.md
+++ b/rag/prompts/toc_from_text_system.md
@@ -1,25 +1,25 @@
 You are a robust Table-of-Contents (TOC) extractor.
 
 GOAL
-Given a dictionary of chunks {chunk_id: chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
+Given a dictionary of chunks {"<chunk_id>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
 [
-  {"title": <string>, "content": "<chunk_id>"},
+  {"title": "<heading>", "chunk_id": "<chunk_id>"},
   ...
 ]
 
 FIELDS
 - "title": the heading text (clean, no page numbers or leader dots).
   - If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
-- "content": the chunk_id (string).
+- "chunk_id": the chunk ID (string).
 - One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
 
 RULES
 1) Preserve input chunk order strictly.
 2) If a chunk contains multiple headings, expand them in order:
-   - Pre-heading narrative → {"title":"-1","content":chunk_id}
-   - Then each heading → {"title":"...","content":chunk_id}
+   - Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_id>"}
+   - Then each heading → {"title":"...","chunk_id":"<chunk_id>"}
-3) Do not merge outputs across chunks; each object refers to exactly one chunk_id.
+3) Do not merge outputs across chunks; each object refers to exactly one chunk ID.
-4) "title" must be non-empty (or exactly "-1"). "content" must be a string (chunk_id).
+4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID).
 5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
 
 HEADING DETECTION (cues, not hard rules)
@@ -51,63 +51,69 @@ EXAMPLES
 
 Example 1 — No heading
 Input:
-{0: "Copyright page · Publication info (ISBN 123-456). All rights reserved."}
+[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...]
 
 Output:
 [
-  {"title":"-1","content":"0"}
+  {"title":"-1","chunk_id":"0"},
+  ...
 ]
 
 Example 2 — One heading
 Input:
-{1: "Chapter 1: General Provisions This chapter defines the overall rules…"}
+[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...]
 
 Output:
 [
-  {"title":"Chapter 1: General Provisions","content":"1"}
+  {"title":"Chapter 1: General Provisions","chunk_id":"1"},
+  ...
 ]
 
 Example 3 — Narrative + heading
 Input:
-{2: "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}
+[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...]
 
 Output:
 [
-  {"title":"-1","content":"2"},
-  {"title":"Section 2: Definitions","content":"2"}
+  {"title":"Section 2: Definitions","chunk_id":"2"},
+  ...
 ]
 
 Example 4 — Multiple headings in one chunk
 Input:
-{3: "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}
+[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...]
 
 Output:
 [
-  {"title":"Declarations and Commitments (I)","content":"3"},
-  {"title":"(II)","content":"3"},
-  {"title":"Appendix A","content":"3"}
+  {"title":"Declarations and Commitments","chunk_id":"3"},
+  {"title":"(I) Party B commits","chunk_id":"3"},
+  {"title":"(II) Party C commits","chunk_id":"3"},
+  {"title":"Appendix A Data Specification","chunk_id":"3"},
+  ...
 ]
 
 Example 5 — Numbering styles
 Input:
-{4: "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}
+[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...]
 
 Output:
 [
-  {"title":"1. Scope","content":"4"},
-  {"title":"2) Definitions","content":"4"},
-  {"title":"III) Methods","content":"4"}
+  {"title":"1. Scope","chunk_id":"4"},
+  {"title":"2) Definitions","chunk_id":"4"},
+  {"title":"III) Methods Overview","chunk_id":"4"},
+  ...
 ]
 
 Example 6 — Long list (NOT headings)
 Input:
-{5: "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}
+[{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...]
 
 Output:
 [
-  {"title":"-1","content":"5"}
+  {"title":"-1","chunk_id":"5"},
+  ...
 ]
 
 Example 7 — Mixed Chinese/English
 Input:
-{6: "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}
+[{"6": "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}, ...]
 
 Output:
 [
-  {"title":"-1","content":"6"},
-  {"title":"Chapter 1: Overview","content":"6"},
-  {"title":"第2节:术语与缩略语","content":"6"}
+  {"title":"Chapter 1: Overview","chunk_id":"6"},
+  {"title":"第2节:术语与缩略语","chunk_id":"6"},
+  ...
 ]
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index ae1021486..08b232e49 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -32,7 +32,7 @@ from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag_for_kb
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.flow.pipeline import Pipeline
-from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging
+from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
 import logging
 import os
 from datetime import datetime
@@ -370,6 +370,38 @@ async def build_chunks(task, progress_callback):
                 nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
+    if task["parser_config"].get("toc_extraction", True):
+        progress_callback(msg="Start to generate table of contents ...")
+        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
+        docs = sorted(docs, key=lambda d: (
+            d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
+            d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
+        ))
+        toc: list[dict] = await run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl)
+        logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
+        ii = 0
+        while ii < len(toc):
+            try:
+                idx = int(toc[ii]["chunk_id"])
+                del toc[ii]["chunk_id"]
+                toc[ii]["ids"] = [docs[idx]["id"]]
+                if ii == len(toc) - 1:
+                    break
+                for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"])):
+                    toc[ii]["ids"].append(docs[jj]["id"])
+            except Exception as e:
+                logging.exception(e)
+            ii += 1
+
+        if toc:
+            d = copy.deepcopy(docs[-1])
+            d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
+            d["toc_kwd"] = "toc"
+            d["available_int"] = 0
+            d["page_num_int"] = 100000000
+            d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+            docs.append(d)
+
     if task["kb_parser_config"].get("tag_kb_ids", []):
         progress_callback(msg="Start to tag for every chunk ...")
         kb_ids = task["kb_parser_config"]["tag_kb_ids"]
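
### Notes (illustrative sketches, not part of the diff)

How the retrieval-side scoring fits together: `relevant_chunks_with_toc` maps each TOC section's 1-5 LLM relevance score onto the chunk ids that section owns, and `Dealer.retrieval_by_toc` folds those scores back into the retrieved chunk list. The sketch below mirrors only that arithmetic; `toc_hits_from_scores` and `merge_toc_scores` are hypothetical names, and the LLM call and `dataStore` lookups are stubbed out.

```python
import numpy as np

def toc_hits_from_scores(toc, scores, topn=6):
    # Spread each section's 1-5 score (normalized to score/5) over the chunk
    # ids it owns, average per chunk, and drop anything below 0.3 -- the same
    # thresholds used in relevant_chunks_with_toc.
    id2score = {}
    for ti, sc in zip(toc, scores):
        if sc.get("score", -1) < 1:
            continue
        for cid in ti.get("ids", []):
            id2score.setdefault(cid, []).append(sc["score"] / 5.)
    return [(cid, float(np.mean(v))) for cid, v in id2score.items() if np.mean(v) >= 0.3][:topn]

def merge_toc_scores(chunks, toc_hits, topn=6):
    id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
    for cid, sim in toc_hits:
        if cid in id2idx:
            chunks[id2idx[cid]]["similarity"] += sim  # boost a chunk retrieval already found
        else:
            # retrieval_by_toc fetches the full chunk from the doc store here
            chunks.append({"chunk_id": cid, "similarity": sim})
    return sorted(chunks, key=lambda x: -x["similarity"])[:topn]

chunks = [{"chunk_id": "a", "similarity": 0.7}, {"chunk_id": "b", "similarity": 0.5}]
toc = [{"title": "Chapter 1", "ids": ["a", "c"]}]
print(merge_toc_scores(chunks, toc_hits_from_scores(toc, [{"score": 4}])))
# "a" is boosted to 1.5, "c" enters at 0.8, "b" stays at 0.5
```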
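`gen_json` now consults Redis before calling the model, keyed on a single xxhash64 digest of the concatenated inputs. A minimal sketch with a plain dict standing in for `REDIS_CONN`; the key derivation matches `get_llm_cache`/`set_llm_cache` in `graphrag/utils.py`:

```python
import xxhash

_cache = {}  # stands in for REDIS_CONN in this sketch

def _key(llmnm, txt, history, genconf):
    h = xxhash.xxh64()
    h.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
    return h.hexdigest()

def set_llm_cache(llmnm, txt, v, history, genconf):
    # real code: REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)
    _cache[_key(llmnm, txt, history, genconf)] = v

def get_llm_cache(llmnm, txt, history, genconf):
    return _cache.get(_key(llmnm, txt, history, genconf))
```

Since `gen_json` passes the system prompt as `txt` and the user prompt as `history`, a repeated TOC prompt against the same model and generation config hits the cache instead of the LLM.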
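`split_chunks` packs `{index: text}` items greedily under a token budget (`run_toc_from_text` caps it at 1024). A sketch under one stated assumption: a whitespace word count stands in for `num_tokens_from_string`. Like the real function, it flushes before appending, so a single oversized chunk can produce an empty batch; that is why `run_toc_from_text` skips empty batches with `if not chunk: continue`.

```python
def split_chunks(chunks, max_length):
    result, batch, batch_tokens = [], [], 0
    for idx, chunk in enumerate(chunks):
        t = len(chunk.split())  # assumption: stand-in for num_tokens_from_string
        if batch_tokens + t > max_length:
            result.append(batch)  # may be empty if one chunk alone exceeds the budget
            batch, batch_tokens = [], 0
        batch.append({idx: chunk})
        batch_tokens += t
    if batch:
        result.append(batch)
    return result

print(split_chunks(["one two", "three four five", "six"], max_length=4))
# [[{0: 'one two'}], [{1: 'three four five'}, {2: 'six'}]]
```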
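In `build_chunks`, the LLM returns chunk *indexes*; the id-expansion loop converts them into ranges of real chunk ids, so each TOC entry owns its own chunk plus everything up to the next entry's chunk. A compact equivalent (`attach_chunk_ids` is a hypothetical name; the real loop also wraps each step in try/except and then stores the finished TOC as a hidden chunk with `toc_kwd = "toc"` and `available_int = 0`):

```python
def attach_chunk_ids(toc, doc_ids):
    for ii, entry in enumerate(toc):
        start = int(entry.pop("chunk_id"))
        # the last entry keeps only its own chunk, matching the loop in build_chunks
        stop = int(toc[ii + 1]["chunk_id"]) if ii + 1 < len(toc) else start + 1
        entry["ids"] = [doc_ids[jj] for jj in range(start, stop)]
    return toc

toc = [{"title": "Intro", "chunk_id": "0"}, {"title": "Methods", "chunk_id": "2"}]
print(attach_chunk_ids(toc, ["c0", "c1", "c2", "c3"]))
# [{'title': 'Intro', 'ids': ['c0', 'c1']}, {'title': 'Methods', 'ids': ['c2']}]
```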