Fix:improve multi-column document detection (#11415)

### What problem does this PR solve?

change:
improve multi-column document detection

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
buua436
2025-11-20 19:00:38 +08:00
committed by GitHub
parent 0d5589bfda
commit c8ab9079b3

View File

@ -33,6 +33,8 @@ import xgboost as xgb
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image from PIL import Image
from pypdf import PdfReader as pdf2_read from pypdf import PdfReader as pdf2_read
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch from common.misc_utils import pip_install_torch
@ -353,7 +355,6 @@ class RAGFlowPdfParser:
def _assign_column(self, boxes, zoomin=3): def _assign_column(self, boxes, zoomin=3):
if not boxes: if not boxes:
return boxes return boxes
if all("col_id" in b for b in boxes): if all("col_id" in b for b in boxes):
return boxes return boxes
@ -361,61 +362,80 @@ class RAGFlowPdfParser:
for b in boxes: for b in boxes:
by_page[b["page_number"]].append(b) by_page[b["page_number"]].append(b)
page_info = {} # pg -> dict(page_w, left_edge, cand_cols) page_cols = {}
counter = Counter()
for pg, bxs in by_page.items(): for pg, bxs in by_page.items():
if not bxs: if not bxs:
page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1} page_cols[pg] = 1
counter[1] += 1
continue continue
if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg: x0s_raw = np.array([b["x0"] for b in bxs], dtype=float)
page_w = self.page_images[pg - 1].size[0] / max(1, zoomin)
left_edge = 0.0
else:
xs0 = [box["x0"] for box in bxs]
xs1 = [box["x1"] for box in bxs]
left_edge = float(min(xs0))
page_w = max(1.0, float(max(xs1) - left_edge))
widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs] min_x0 = np.min(x0s_raw)
median_w = float(np.median(widths)) if widths else 1.0 max_x1 = np.max([b["x1"] for b in bxs])
width = max_x1 - min_x0
raw_cols = int(page_w / max(1.0, median_w)) INDENT_TOL = width * 0.12
x0s = []
for x in x0s_raw:
if abs(x - min_x0) < INDENT_TOL:
x0s.append([min_x0])
else:
x0s.append([x])
x0s = np.array(x0s, dtype=float)
# cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1 max_try = min(4, len(bxs))
cand = raw_cols if max_try < 2:
max_try = 1
best_k = 1
best_score = -1
page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand} for k in range(1, max_try + 1):
counter[cand] += 1 km = KMeans(n_clusters=k, n_init="auto")
labels = km.fit_predict(x0s)
logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}") centers = np.sort(km.cluster_centers_.flatten())
if len(centers) > 1:
try:
score = silhouette_score(x0s, labels)
except ValueError:
continue
else:
score = 0
print(f"{k=},{score=}",flush=True)
if score > best_score:
best_score = score
best_k = k
global_cols = counter.most_common(1)[0][0] page_cols[pg] = best_k
logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
global_cols = Counter(page_cols.values()).most_common(1)[0][0]
logging.info(f"Global column_num decided by majority: {global_cols}") logging.info(f"Global column_num decided by majority: {global_cols}")
for pg, bxs in by_page.items(): for pg, bxs in by_page.items():
if not bxs: if not bxs:
continue continue
k = page_cols[pg]
if len(bxs) < k:
k = 1
x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
km = KMeans(n_clusters=k, n_init="auto")
labels = km.fit_predict(x0s)
page_w = page_info[pg]["page_w"] centers = km.cluster_centers_.flatten()
left_edge = page_info[pg]["left_edge"] order = np.argsort(centers)
if global_cols == 1: remap = {orig: new for new, orig in enumerate(order)}
for box in bxs:
box["col_id"] = 0
continue
for box in bxs: for b, lb in zip(bxs, labels):
w = box["x1"] - box["x0"] b["col_id"] = remap[lb]
if w >= 0.8 * page_w:
box["col_id"] = 0 grouped = defaultdict(list)
continue for b in bxs:
cx = 0.5 * (box["x0"] + box["x1"]) grouped[b["col_id"]].append(b)
norm_cx = (cx - left_edge) / page_w
norm_cx = max(0.0, min(norm_cx, 0.999999))
box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols))
return boxes return boxes
@ -1303,7 +1323,10 @@ class RAGFlowPdfParser:
positions = [] positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss): for ii, (pns, left, right, top, bottom) in enumerate(poss):
right = left + max_width if 0 < ii < len(poss) - 1:
right = max(left + 10, right)
else:
right = left + max_width
bottom *= ZM bottom *= ZM
for pn in pns[1:]: for pn in pns[1:]:
if 0 <= pn - 1 < page_count: if 0 <= pn - 1 < page_count: