mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix:improve multi-column document detection (#11415)
### What problem does this PR solve? change: improve multi-column document detection ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -33,6 +33,8 @@ import xgboost as xgb
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader as pdf2_read
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.metrics import silhouette_score
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import pip_install_torch
|
||||
@ -353,7 +355,6 @@ class RAGFlowPdfParser:
|
||||
def _assign_column(self, boxes, zoomin=3):
|
||||
if not boxes:
|
||||
return boxes
|
||||
|
||||
if all("col_id" in b for b in boxes):
|
||||
return boxes
|
||||
|
||||
@ -361,61 +362,80 @@ class RAGFlowPdfParser:
|
||||
for b in boxes:
|
||||
by_page[b["page_number"]].append(b)
|
||||
|
||||
page_info = {} # pg -> dict(page_w, left_edge, cand_cols)
|
||||
counter = Counter()
|
||||
page_cols = {}
|
||||
|
||||
for pg, bxs in by_page.items():
|
||||
if not bxs:
|
||||
page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1}
|
||||
counter[1] += 1
|
||||
page_cols[pg] = 1
|
||||
continue
|
||||
|
||||
if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg:
|
||||
page_w = self.page_images[pg - 1].size[0] / max(1, zoomin)
|
||||
left_edge = 0.0
|
||||
else:
|
||||
xs0 = [box["x0"] for box in bxs]
|
||||
xs1 = [box["x1"] for box in bxs]
|
||||
left_edge = float(min(xs0))
|
||||
page_w = max(1.0, float(max(xs1) - left_edge))
|
||||
x0s_raw = np.array([b["x0"] for b in bxs], dtype=float)
|
||||
|
||||
widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs]
|
||||
median_w = float(np.median(widths)) if widths else 1.0
|
||||
min_x0 = np.min(x0s_raw)
|
||||
max_x1 = np.max([b["x1"] for b in bxs])
|
||||
width = max_x1 - min_x0
|
||||
|
||||
raw_cols = int(page_w / max(1.0, median_w))
|
||||
INDENT_TOL = width * 0.12
|
||||
x0s = []
|
||||
for x in x0s_raw:
|
||||
if abs(x - min_x0) < INDENT_TOL:
|
||||
x0s.append([min_x0])
|
||||
else:
|
||||
x0s.append([x])
|
||||
x0s = np.array(x0s, dtype=float)
|
||||
|
||||
max_try = min(4, len(bxs))
|
||||
if max_try < 2:
|
||||
max_try = 1
|
||||
best_k = 1
|
||||
best_score = -1
|
||||
|
||||
# cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1
|
||||
cand = raw_cols
|
||||
for k in range(1, max_try + 1):
|
||||
km = KMeans(n_clusters=k, n_init="auto")
|
||||
labels = km.fit_predict(x0s)
|
||||
|
||||
page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand}
|
||||
counter[cand] += 1
|
||||
centers = np.sort(km.cluster_centers_.flatten())
|
||||
if len(centers) > 1:
|
||||
try:
|
||||
score = silhouette_score(x0s, labels)
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
score = 0
|
||||
print(f"{k=},{score=}",flush=True)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_k = k
|
||||
|
||||
logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}")
|
||||
page_cols[pg] = best_k
|
||||
logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
|
||||
|
||||
global_cols = counter.most_common(1)[0][0]
|
||||
|
||||
global_cols = Counter(page_cols.values()).most_common(1)[0][0]
|
||||
logging.info(f"Global column_num decided by majority: {global_cols}")
|
||||
|
||||
|
||||
for pg, bxs in by_page.items():
|
||||
if not bxs:
|
||||
continue
|
||||
k = page_cols[pg]
|
||||
if len(bxs) < k:
|
||||
k = 1
|
||||
x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
|
||||
km = KMeans(n_clusters=k, n_init="auto")
|
||||
labels = km.fit_predict(x0s)
|
||||
|
||||
page_w = page_info[pg]["page_w"]
|
||||
left_edge = page_info[pg]["left_edge"]
|
||||
centers = km.cluster_centers_.flatten()
|
||||
order = np.argsort(centers)
|
||||
|
||||
if global_cols == 1:
|
||||
for box in bxs:
|
||||
box["col_id"] = 0
|
||||
continue
|
||||
remap = {orig: new for new, orig in enumerate(order)}
|
||||
|
||||
for box in bxs:
|
||||
w = box["x1"] - box["x0"]
|
||||
if w >= 0.8 * page_w:
|
||||
box["col_id"] = 0
|
||||
continue
|
||||
cx = 0.5 * (box["x0"] + box["x1"])
|
||||
norm_cx = (cx - left_edge) / page_w
|
||||
norm_cx = max(0.0, min(norm_cx, 0.999999))
|
||||
box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols))
|
||||
for b, lb in zip(bxs, labels):
|
||||
b["col_id"] = remap[lb]
|
||||
|
||||
grouped = defaultdict(list)
|
||||
for b in bxs:
|
||||
grouped[b["col_id"]].append(b)
|
||||
|
||||
return boxes
|
||||
|
||||
@ -1303,7 +1323,10 @@ class RAGFlowPdfParser:
|
||||
|
||||
positions = []
|
||||
for ii, (pns, left, right, top, bottom) in enumerate(poss):
|
||||
right = left + max_width
|
||||
if 0 < ii < len(poss) - 1:
|
||||
right = max(left + 10, right)
|
||||
else:
|
||||
right = left + max_width
|
||||
bottom *= ZM
|
||||
for pn in pns[1:]:
|
||||
if 0 <= pn - 1 < page_count:
|
||||
|
||||
Reference in New Issue
Block a user