diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index f61019377..add454ade 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -437,16 +437,16 @@ def not_title(txt):
     return re.search(r"[,;,。;!!]", txt)
 
 
 def tree_merge(bull, sections, depth):
-    
+
     if not sections or bull < 0:
         return sections
     if isinstance(sections[0], type("")):
         sections = [(s, "") for s in sections]
-    
+
     # filter out position information in pdf sections
     sections = [(t, o) for t, o in sections if t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
-    
+
     def get_level(bull, section):
         text, layout = section
         text = re.sub(r"\u3000", " ", text).strip()
@@ -465,7 +465,7 @@ def tree_merge(bull, sections, depth):
         level, text = get_level(bull, section)
         if not text.strip("\n"):
             continue
-        
+
         lines.append((level, text))
         level_set.add(level)
 
@@ -608,6 +608,26 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
             cks[-1] += t
             tk_nums[-1] += tnum
 
+    custom_delimiters = [m.group(1) for m in re.finditer(r"`([^`]+)`", delimiter)]
+    has_custom = bool(custom_delimiters)
+    if has_custom:
+        custom_pattern = "|".join(re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True))
+        cks, tk_nums = [], []
+        for sec, pos in sections:
+            split_sec = re.split(r"(%s)" % custom_pattern, sec, flags=re.DOTALL)
+            for sub_sec in split_sec:
+                if re.fullmatch(custom_pattern, sub_sec or ""):
+                    continue
+                text = "\n" + sub_sec
+                local_pos = pos
+                if num_tokens_from_string(text) < 8:
+                    local_pos = ""
+                if local_pos and text.find(local_pos) < 0:
+                    text += local_pos
+                cks.append(text)
+                tk_nums.append(num_tokens_from_string(text))
+        return cks
+
     dels = get_delimiters(delimiter)
     for sec, pos in sections:
         if num_tokens_from_string(sec) < chunk_token_num:
@@ -657,6 +677,29 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
             result_images[-1] = concat_img(result_images[-1], image)
             tk_nums[-1] += tnum
 
+    custom_delimiters = [m.group(1) for m in re.finditer(r"`([^`]+)`", delimiter)]
+    has_custom = bool(custom_delimiters)
+    if has_custom:
+        custom_pattern = "|".join(re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True))
+        cks, result_images, tk_nums = [], [], []
+        for text, image in zip(texts, images):
+            text_str = text[0] if isinstance(text, tuple) else text
+            text_pos = text[1] if isinstance(text, tuple) and len(text) > 1 else ""
+            split_sec = re.split(r"(%s)" % custom_pattern, text_str)
+            for sub_sec in split_sec:
+                if re.fullmatch(custom_pattern, sub_sec or ""):
+                    continue
+                text_seg = "\n" + sub_sec
+                local_pos = text_pos
+                if num_tokens_from_string(text_seg) < 8:
+                    local_pos = ""
+                if local_pos and text_seg.find(local_pos) < 0:
+                    text_seg += local_pos
+                cks.append(text_seg)
+                result_images.append(image)
+                tk_nums.append(num_tokens_from_string(text_seg))
+        return cks, result_images
+
     dels = get_delimiters(delimiter)
     for text, image in zip(texts, images):
         # if text is tuple, unpack it
@@ -748,6 +791,23 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
             images[-1] = concat_img(images[-1], image)
             tk_nums[-1] += tnum
 
+    custom_delimiters = [m.group(1) for m in re.finditer(r"`([^`]+)`", delimiter)]
+    has_custom = bool(custom_delimiters)
+    if has_custom:
+        custom_pattern = "|".join(re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True))
+        cks, images, tk_nums = [], [], []
+        pattern = r"(%s)" % custom_pattern
+        for sec, image in sections:
+            split_sec = re.split(pattern, sec)
+            for sub_sec in split_sec:
+                if not sub_sec or re.fullmatch(custom_pattern, sub_sec):
+                    continue
+                text_seg = "\n" + sub_sec
+                cks.append(text_seg)
+                images.append(image)
+                tk_nums.append(num_tokens_from_string(text_seg))
+        return cks, images
+
     dels = get_delimiters(delimiter)
     pattern = r"(%s)" % dels
 
@@ -789,7 +849,7 @@ class Node:
         self.level = level
         self.depth = depth
         self.texts = texts or []
-        self.children = [] 
+        self.children = []
 
     def add_child(self, child_node):
         self.children.append(child_node)
@@ -835,7 +895,7 @@ class Node:
         return self
 
     def get_tree(self):
-        tree_list = [] 
+        tree_list = []
         self._dfs(self, tree_list, [])
         return tree_list
 
@@ -860,7 +920,7 @@ class Node:
             # A leaf title within depth emits its title path as a chunk (header-only section)
             elif not child and (1 <= level <= self.depth):
                 tree_list.append("\n".join(path_titles))
-            
+
             # Recurse into children with the updated title path
             for c in child:
-                self._dfs(c, tree_list, path_titles)
\ No newline at end of file
+                self._dfs(c, tree_list, path_titles)
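
Below is a minimal, self-contained sketch (not part of the patch) of the custom-delimiter behavior the three branches above share: backtick-quoted tokens in the delimiter string are extracted, sorted longest-first, and used as literal split points, and the position tag is only appended to chunks of at least 8 tokens. num_tokens_from_string is stubbed with a whitespace tokenizer so the snippet runs standalone, and split_on_custom_delimiters is an illustrative name, not a function in rag/nlp.

import re


def num_tokens_from_string(s: str) -> int:
    # stand-in for rag.nlp's tokenizer: whitespace token count
    return len(s.split())


def split_on_custom_delimiters(sections, delimiter):
    # backtick-quoted substrings of `delimiter` become literal delimiters
    custom_delimiters = [m.group(1) for m in re.finditer(r"`([^`]+)`", delimiter)]
    if not custom_delimiters:
        return None  # no custom delimiters: caller falls back to the token-merging path
    # longest-first so overlapping delimiters (e.g. "###" vs "#") match greedily
    pattern = "|".join(re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True))
    chunks = []
    for sec, pos in sections:
        for sub_sec in re.split(r"(%s)" % pattern, sec, flags=re.DOTALL):
            if re.fullmatch(pattern, sub_sec or ""):
                continue  # drop the captured delimiter tokens themselves
            text = "\n" + sub_sec
            # chunks shorter than 8 tokens do not get the position tag appended
            if pos and num_tokens_from_string(text) >= 8 and text.find(pos) < 0:
                text += pos
            chunks.append(text)
    return chunks


if __name__ == "__main__":
    sections = [("chapter one\n\nchapter two\n\nchapter three", "")]
    print(split_on_custom_delimiters(sections, "\n。;!?`\n\n`"))
    # ['\nchapter one', '\nchapter two', '\nchapter three']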