From 8e4d011b15b26065b053b4cc9c3152f28f1b255f Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Wed, 17 Dec 2025 16:50:36 +0800 Subject: [PATCH] Fix: parent-children chunking method. (#11997) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --- api/apps/canvas_app.py | 2 +- api/apps/chunk_app.py | 1 + common/metadata_utils.py | 65 +++++++++++++++++++++++++++- graphrag/general/extractor.py | 2 +- rag/app/naive.py | 11 +++-- rag/flow/splitter/splitter.py | 9 +--- rag/nlp/__init__.py | 25 +++++++---- rag/prompts/generator.py | 10 +++++ rag/prompts/meta_data.md | 13 ++++++ rag/svr/task_executor.py | 79 +++++++++++++++++++---------------- 10 files changed, 160 insertions(+), 57 deletions(-) create mode 100644 rag/prompts/meta_data.md diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 64b0d0f55..ab6367eb5 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -147,7 +147,7 @@ async def run(): if cvs.canvas_category == CanvasCategory.DataFlow: task_id = get_uuid() Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) - ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0) + ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0) if not ok: return get_data_error_result(message=error_message) return get_json_result(data={"message_id": task_id}) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 26df98148..a5df41983 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -386,6 +386,7 @@ async def retrieval_test(): LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) for c in ranks["chunks"]: c.pop("vector", None) diff --git a/common/metadata_utils.py b/common/metadata_utils.py index 957ed3ece..ca6d36598 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
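Reviewer note: the retrieval-side half of the parent-children method is the new `retrieval_by_children` call in `api/apps/chunk_app.py` above. The retriever itself is not part of this patch; the sketch below is only a hypothetical illustration of the child-to-parent mapping such a step presumably performs, based on the `mom_with_weight` field that the indexing side attaches to child chunks (see `rag/nlp/__init__.py` further down). The helper name and the deduplication strategy are assumptions, not the actual `settings.retriever` implementation.

```python
# Hypothetical sketch (not the real retriever code): retrieved child chunks carry their
# parent's full text in "mom_with_weight"; generation should be grounded on the parent
# context, so swap the content and deduplicate children that share the same parent.
def map_children_to_parents(chunks: list[dict]) -> list[dict]:
    seen_parents = set()
    out = []
    for ck in chunks:
        parent_text = ck.get("mom_with_weight")
        if not parent_text:              # ordinary chunk without a parent: keep as-is
            out.append(ck)
            continue
        if parent_text in seen_parents:  # several children hit the same parent: keep one
            continue
        seen_parents.add(parent_text)
        merged = dict(ck)
        merged["content_with_weight"] = parent_text
        out.append(merged)
    return out
```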
# -from typing import Any, Callable +import logging +from typing import Any, Callable, Dict + +import json_repair from rag.prompts.generator import gen_meta_filter @@ -140,3 +143,63 @@ async def apply_meta_data_filter( doc_ids = ["-999"] return doc_ids + + +def update_metadata_to(metadata, meta): + if not meta: + return metadata + if isinstance(meta, str): + try: + meta = json_repair.loads(meta) + except Exception: + logging.error("Meta data format error.") + return metadata + if not isinstance(meta, dict): + return metadata + for k, v in meta.items(): + if isinstance(v, list): + v = [vv for vv in v if isinstance(vv, str)] + if not v: + continue + if not isinstance(v, list) and not isinstance(v, str): + continue + if k not in metadata: + metadata[k] = v + continue + if isinstance(metadata[k], list): + if isinstance(v, list): + metadata[k].extend(v) + else: + metadata[k].append(v) + else: + metadata[k] = v + + return metadata + + +def metadata_schema(metadata: list|None) -> Dict[str, Any]: + if not metadata: + return {} + properties = {} + + for item in metadata: + key = item.get("key") + if not key: + continue + + prop_schema = { + "description": item.get("description", "") + } + if "enum" in item and item["enum"]: + prop_schema["enum"] = item["enum"] + prop_schema["type"] = "string" + + properties[key] = prop_schema + + json_schema = { + "type": "object", + "properties": properties, + } + + json_schema["additionalProperties"] = False + return json_schema \ No newline at end of file diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 86c971c4c..a965a30c4 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -78,7 +78,7 @@ class Extractor: raise TaskCanceledException(f"Task {task_id} was cancelled") try: - response = self._llm.chat(system_msg[0]["content"], hist, conf) + response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf)) response = re.sub(r"^.*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) diff --git a/rag/app/naive.py b/rag/app/naive.py index 4d07d0983..579ed8380 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -635,9 +635,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca "parser_config", { "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True}) - child_deli = re.findall(r"`([^`]+)`", parser_config.get("children_delimiter", "")) - child_deli = sorted(set(child_deli), key=lambda x: -len(x)) - child_deli = "|".join(re.escape(t) for t in child_deli if t) + child_deli = parser_config.get("children_delimiter", "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8') + cust_child_deli = re.findall(r"`([^`]+)`", child_deli) + child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli)) + if cust_child_deli: + cust_child_deli = sorted(set(cust_child_deli), key=lambda x: -len(x)) + cust_child_deli = "|".join(re.escape(t) for t in cust_child_deli if t) + child_deli += cust_child_deli + is_markdown = False table_context_size = max(0, int(parser_config.get("table_context_size", 0) or 0)) image_context_size = max(0, int(parser_config.get("image_context_size", 0) or 0)) diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index e0174800f..45abb547a 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -60,14 +60,7 @@ class Splitter(ProcessBase): deli += f"`{d}`" else: deli += d - child_deli = "" - for d 
in self._param.children_delimiters:
-            if len(d) > 1:
-                child_deli += f"`{d}`"
-            else:
-                child_deli += d
-        child_deli = [m.group(1) for m in re.finditer(r"`([^`]+)`", child_deli)]
-        custom_pattern = "|".join(re.escape(t) for t in sorted(set(child_deli), key=len, reverse=True))
+        custom_pattern = "|".join(re.escape(t) for t in sorted(set(self._param.children_delimiters), key=len, reverse=True))

         self.set_output("output_format", "chunks")
         self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index d6c23217b..1619eadbe 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -273,6 +273,21 @@ def tokenize(d, txt, eng):
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])


+def split_with_pattern(d, pattern:str, content:str, eng) -> list:
+    docs = []
+    txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
+    for j in range(0, len(txts), 2):
+        txt = txts[j]
+        if not txt:
+            continue
+        if j + 1 < len(txts):
+            txt += txts[j+1]
+        dd = copy.deepcopy(d)
+        tokenize(dd, txt, eng)
+        docs.append(dd)
+    return docs
+
+
 def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=None):
     res = []
     # wrap up as es documents
@@ -293,10 +308,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=

         if child_delimiters_pattern:
             d["mom_with_weight"] = ck
-            for txt in re.split(r"(%s)" % child_delimiters_pattern, ck, flags=re.DOTALL):
-                dd = copy.deepcopy(d)
-                tokenize(dd, txt, eng)
-                res.append(dd)
+            res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
             continue

         tokenize(d, ck, eng)
@@ -316,10 +328,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
         add_positions(d, [[ii]*5])
         if child_delimiters_pattern:
             d["mom_with_weight"] = ck
-            for txt in re.split(r"(%s)" % child_delimiters_pattern, ck, flags=re.DOTALL):
-                dd = copy.deepcopy(d)
-                tokenize(dd, txt, eng)
-                res.append(dd)
+            res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
             continue
         tokenize(d, ck, eng)
         res.append(d)
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index 621a460ad..82f8ca556 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -821,3 +821,13 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
     except Exception as e:
         logging.exception(e)
         return []
+
+
+META_DATA = load_prompt("meta_data")
+async def gen_metadata(chat_mdl, schema:dict, content:str):
+    template = PROMPT_JINJA_ENV.from_string(META_DATA)
+    system_prompt = template.render(content=content, schema=schema)
+    user_prompt = "Output: "
+    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
+    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
\ No newline at end of file
diff --git a/rag/prompts/meta_data.md b/rag/prompts/meta_data.md
new file mode 100644
index 000000000..440396df4
--- /dev/null
+++ b/rag/prompts/meta_data.md
@@ -0,0 +1,13 @@
+Extract important structured information from the given content.
+Output ONLY a valid JSON string with no additional text.
+If no important structured information is found, output an empty JSON object: {}.
+
+The schema of the important structured information is as follows:
+
+{{ schema }}
+
+---------------------------
+The given content is as follows:
+
+{{ content }}
+
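Reviewer note: the `split_with_pattern` helper added in `rag/nlp/__init__.py` above relies on `re.split` with a capturing group, so the delimiter that ends each child fragment is kept and re-attached to it before tokenizing. A small standalone illustration (the sample text and delimiter set are invented for the example):

```python
import re

# Child delimiters are escaped and sorted longest-first, as the splitter/naive parsers do.
delimiters = {"\n\n", "。"}
pattern = "|".join(re.escape(t) for t in sorted(delimiters, key=len, reverse=True))

parent = "第一句。第二句。\n\nSecond paragraph."
parts = re.split(r"(%s)" % pattern, parent, flags=re.DOTALL)
# With the capturing group, parts alternates [text, delimiter, text, delimiter, ...]:
# ['第一句', '。', '第二句', '。', '', '\n\n', 'Second paragraph.']

children = []
for j in range(0, len(parts), 2):
    txt = parts[j]
    if not txt:
        continue
    if j + 1 < len(parts):
        txt += parts[j + 1]  # re-attach the trailing delimiter, as split_with_pattern does
    children.append(txt)
# children -> ['第一句。', '第二句。', 'Second paragraph.']
# Each child becomes its own indexed document while "mom_with_weight" keeps the parent text.
```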
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 1de9e7049..8668d9178 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -23,19 +23,19 @@ import sys
 import threading
 import time
-import json_repair
-
 from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from common.connection_utils import timeout
+from common.metadata_utils import update_metadata_to, metadata_schema
 from rag.utils.base64_image import image2id
 from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
-from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
+from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \
+    gen_metadata
 import logging
 import os
 from datetime import datetime
@@ -368,6 +368,45 @@ async def build_chunks(task, progress_callback):
                 raise
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

+    if task["parser_config"].get("enable_metadata", False) and task["parser_config"].get("metadata"):
+        st = timer()
+        progress_callback(msg="Start to generate meta-data for every chunk ...")
+        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
+
+        async def gen_metadata_task(chat_mdl, d):
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata")
+            if not cached:
+                async with chat_limiter:
+                    cached = await gen_metadata(chat_mdl,
+                                                metadata_schema(task["parser_config"]["metadata"]),
+                                                d["content_with_weight"])
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata")
+            if cached:
+                d["metadata_obj"] = cached
+        tasks = []
+        for d in docs:
+            tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error("Error in gen_metadata_task", exc_info=e)
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
+        metadata = {}
+        for d in docs:
+            metadata = update_metadata_to(metadata, d.get("metadata_obj"))
+            d.pop("metadata_obj", None)
+        if metadata:
+            e, doc = DocumentService.get_by_id(task["doc_id"])
+            if e:
+                if isinstance(doc.meta_fields, str):
+                    doc.meta_fields = json.loads(doc.meta_fields)
+                metadata = update_metadata_to(metadata, doc.meta_fields)
+            DocumentService.update_by_id(task["doc_id"], {"meta_fields": metadata})
+        progress_callback(msg="Meta-data generation for {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
+
     if task["kb_parser_config"].get("tag_kb_ids", []):
         progress_callback(msg="Start to tag for every chunk ...")
         kb_ids = task["kb_parser_config"]["tag_kb_ids"]
@@ -602,36 +641,6 @@ async def run_dataflow(task: dict):

     metadata = {}

-    def dict_update(meta):
-        nonlocal metadata
-        if not meta:
-            return
-        if isinstance(meta, str):
-            try:
-                meta = json_repair.loads(meta)
-            except Exception:
-
logging.error("Meta data format error.") - return - if not isinstance(meta, dict): - return - for k, v in meta.items(): - if isinstance(v, list): - v = [vv for vv in v if isinstance(vv, str)] - if not v: - continue - if not isinstance(v, list) and not isinstance(v, str): - continue - if k not in metadata: - metadata[k] = v - continue - if isinstance(metadata[k], list): - if isinstance(v, list): - metadata[k].extend(v) - else: - metadata[k].append(v) - else: - metadata[k] = v - for ck in chunks: ck["doc_id"] = doc_id ck["kb_id"] = [str(task["kb_id"])] @@ -656,7 +665,7 @@ async def run_dataflow(task: dict): ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) del ck["summary"] if "metadata" in ck: - dict_update(ck["metadata"]) + metadata = update_metadata_to(metadata, ck["metadata"]) del ck["metadata"] if "content_with_weight" not in ck: ck["content_with_weight"] = ck["text"] @@ -670,7 +679,7 @@ async def run_dataflow(task: dict): if e: if isinstance(doc.meta_fields, str): doc.meta_fields = json.loads(doc.meta_fields) - dict_update(doc.meta_fields) + metadata = update_metadata_to(metadata, doc.meta_fields) DocumentService.update_by_id(doc_id, {"meta_fields": metadata}) start_ts = timer()
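Reviewer note: both `build_chunks` and `run_dataflow` now funnel per-chunk metadata through `update_metadata_to` and merge the document's existing `meta_fields` last, so pre-existing scalar values win and list values are appended. A usage sketch of that merge order, with invented values:

```python
from common.metadata_utils import update_metadata_to

# Values below are illustrative only.
metadata = {}
for chunk_meta in [{"author": "extracted-A", "tags": ["q3"]},   # produced for chunk 1
                   {"tags": ["finance"]}]:                      # produced for chunk 2
    metadata = update_metadata_to(metadata, chunk_meta)
# metadata -> {"author": "extracted-A", "tags": ["q3", "finance"]}

existing_meta_fields = {"author": "curated-B", "tags": ["reviewed"]}  # doc.meta_fields
metadata = update_metadata_to(metadata, existing_meta_fields)
# metadata -> {"author": "curated-B", "tags": ["q3", "finance", "reviewed"]}
# DocumentService.update_by_id(doc_id, {"meta_fields": metadata}) then persists the result.
```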