Refa: improve image table context (#12244)

### What problem does this PR solve?

Improve the context attached to image and table chunks.

Current strategy in `attach_media_context` (a simplified sketch of the primary matching rule follows this list):

- Order by position when possible: if any chunk carries page/position info,
sort by (page, top, left); otherwise keep the original order.
- Apply only to media chunks: images use `image_context_size`, tables use
`table_context_size`.
- Primary matching: on the same page, consider text chunks whose vertical
span overlaps the media, then pick the one with the closest vertical
midpoint.
- Fallback matching: if nothing overlaps on that page, choose the nearest text
chunk on the same page (page-head media use the next text chunk; page-tail
media use the previous one).
- Context extraction: inside the chosen text chunk, find a mid-sentence
boundary near the text midpoint, then take up to the context-size token
budget of sentences before and after that boundary.
- No multi-chunk stitching: context comes from a single text chunk to
avoid mixing unrelated segments.
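
To make the primary matching rule concrete, here is a minimal, self-contained sketch of the idea. It is not the PR's code: it uses plain `(page, top, bottom)` tuples instead of the real chunk dicts, and it omits the page-head/page-tail fallback and the context-extraction step.

```python
def pick_context_chunk(media, texts):
    """Toy version of the primary rule: media is (page, top, bottom);
    texts are (index, page, top, bottom). Returns the chosen text index or None."""
    m_page, m_top, m_bottom = media
    m_mid = (m_top + m_bottom) / 2.0
    best_idx, best_dist = None, None
    for idx, page, top, bottom in texts:
        if page != m_page:
            continue                        # only consider text on the media's page
        if bottom < m_top or top > m_bottom:
            continue                        # vertical spans must overlap
        dist = abs((top + bottom) / 2.0 - m_mid)
        if best_dist is None or dist < best_dist:
            best_idx, best_dist = idx, dist  # closest vertical midpoint wins
    return best_idx

# A table spanning y=100..200 on page 3: chunk 1 overlaps it and wins.
print(pick_context_chunk((3, 100.0, 200.0),
                         [(0, 3, 10.0, 90.0), (1, 3, 180.0, 260.0), (2, 4, 120.0, 150.0)]))  # -> 1
```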

### Type of change

- [x] Refactoring

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Authored by Yongteng Lei on 2025-12-26 17:55:32 +08:00; committed by GitHub.
Commit 51bc41b2e8 (parent 9de3ecc4a8). 4 changed files with 165 additions and 43 deletions.


@@ -376,6 +376,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
    order chunks before collecting context; otherwise keep original order.
    """
    from . import rag_tokenizer
    if not chunks or (table_context_size <= 0 and image_context_size <= 0):
        return chunks
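
As the guard above shows, the whole pass is a no-op unless at least one budget is positive. A hedged call sketch (the chunk list itself comes from the parsers elsewhere in the codebase; only the signature is taken from this diff):

```python
# chunks: list of chunk dicts produced by the document parser (schema not shown here).
chunks = attach_media_context(chunks, table_context_size=512, image_context_size=256)
# Leaving both sizes at their default of 0 returns the list unchanged.
```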
@@ -418,6 +419,51 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
        sentences.append(buf)
        return sentences

    def get_bounds_by_page(ck):
        bounds = {}
        try:
            if ck.get("position_int"):
                for pos in ck["position_int"]:
                    if not pos or len(pos) < 5:
                        continue
                    pn, _, _, top, bottom = pos
                    if pn is None or top is None:
                        continue
                    top_val = float(top)
                    bottom_val = float(bottom) if bottom is not None else top_val
                    if bottom_val < top_val:
                        top_val, bottom_val = bottom_val, top_val
                    pn = int(pn)
                    if pn in bounds:
                        bounds[pn] = (min(bounds[pn][0], top_val), max(bounds[pn][1], bottom_val))
                    else:
                        bounds[pn] = (top_val, bottom_val)
            else:
                pn = None
                if ck.get("page_num_int"):
                    pn = ck["page_num_int"][0]
                elif ck.get("page_number") is not None:
                    pn = ck.get("page_number")
                if pn is None:
                    return bounds
                top = None
                if ck.get("top_int"):
                    top = ck["top_int"][0]
                elif ck.get("top") is not None:
                    top = ck.get("top")
                if top is None:
                    return bounds
                bottom = ck.get("bottom")
                pn = int(pn)
                top_val = float(top)
                bottom_val = float(bottom) if bottom is not None else top_val
                if bottom_val < top_val:
                    top_val, bottom_val = bottom_val, top_val
                bounds[pn] = (top_val, bottom_val)
        except Exception:
            return {}
        return bounds

    def trim_to_tokens(text, token_budget, from_tail=False):
        if token_budget <= 0 or not text:
            return ""
@@ -442,6 +488,55 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
        collected = list(reversed(collected))
        return "".join(collected)

    def find_mid_sentence_index(sentences):
        if not sentences:
            return 0
        total = sum(max(0, num_tokens_from_string(s)) for s in sentences)
        if total <= 0:
            return max(0, len(sentences) // 2)
        target = total / 2.0
        best_idx = 0
        best_diff = None
        cum = 0
        for i, s in enumerate(sentences):
            cum += max(0, num_tokens_from_string(s))
            diff = abs(cum - target)
            if best_diff is None or diff < best_diff:
                best_diff = diff
                best_idx = i
        return best_idx

    def collect_context_from_sentences(sentences, boundary_idx, token_budget):
        prev_ctx = []
        remaining_prev = token_budget
        for s in reversed(sentences[:boundary_idx + 1]):
            if remaining_prev <= 0:
                break
            tks = num_tokens_from_string(s)
            if tks <= 0:
                continue
            if tks > remaining_prev:
                s = trim_to_tokens(s, remaining_prev, from_tail=True)
                tks = num_tokens_from_string(s)
            prev_ctx.append(s)
            remaining_prev -= tks
        prev_ctx.reverse()
        next_ctx = []
        remaining_next = token_budget
        for s in sentences[boundary_idx + 1:]:
            if remaining_next <= 0:
                break
            tks = num_tokens_from_string(s)
            if tks <= 0:
                continue
            if tks > remaining_next:
                s = trim_to_tokens(s, remaining_next, from_tail=False)
                tks = num_tokens_from_string(s)
            next_ctx.append(s)
            remaining_next -= tks
        return prev_ctx, next_ctx

    def extract_position(ck):
        pn = None
        top = None
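
A rough, self-contained illustration of the sentence-budget logic above. `count_tokens` is a whitespace stand-in for `num_tokens_from_string`, and trimming of an oversized sentence is skipped to keep the sketch short, so this is an approximation of the helpers, not a copy of them.

```python
def count_tokens(s):
    return len(s.split())                      # stand-in for num_tokens_from_string

def demo(sentences, budget):
    # Boundary: the sentence whose cumulative token count is closest to half the total.
    total = sum(count_tokens(s) for s in sentences)
    target, cum, boundary, best = total / 2.0, 0, 0, None
    for i, s in enumerate(sentences):
        cum += count_tokens(s)
        if best is None or abs(cum - target) < best:
            best, boundary = abs(cum - target), i
    # Spend up to `budget` tokens on each side of the boundary.
    prev_ctx, spent = [], 0
    for s in reversed(sentences[:boundary + 1]):
        if spent >= budget:
            break
        prev_ctx.insert(0, s)
        spent += count_tokens(s)
    next_ctx, spent = [], 0
    for s in sentences[boundary + 1:]:
        if spent >= budget:
            break
        next_ctx.append(s)
        spent += count_tokens(s)
    return prev_ctx, next_ctx

print(demo(["One two three.", "Four five.", "Six seven eight nine.", "Ten."], budget=5))
# (['One two three.', 'Four five.'], ['Six seven eight nine.', 'Ten.'])
```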
@@ -481,7 +576,14 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
    else:
        ordered_indices = [idx for idx, _ in indexed]
    total = len(ordered_indices)

    text_bounds = []
    for idx, ck in indexed:
        if not is_text_chunk(ck):
            continue
        bounds = get_bounds_by_page(ck)
        if bounds:
            text_bounds.append((idx, bounds))

    for sorted_pos, idx in enumerate(ordered_indices):
        ck = chunks[idx]
        token_budget = image_context_size if is_image_chunk(ck) else table_context_size if is_table_chunk(ck) else 0
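
The `text_bounds` list built above pairs each text chunk's original index with its per-page vertical extent, so the matching loop below only examines text chunks that actually carry position data. Its shape looks roughly like this (indices and coordinates invented):

```python
text_bounds = [
    (0, {1: (72.0, 210.5)}),                      # text chunk 0: one block on page 1
    (3, {1: (220.0, 380.0), 2: (40.0, 95.0)}),    # text chunk 3 spans pages 1 and 2
]
```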
@@ -489,45 +591,51 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
            continue
        prev_ctx = []
        remaining_prev = token_budget
        for prev_idx in range(sorted_pos - 1, -1, -1):
            if remaining_prev <= 0:
                break
            neighbor_idx = ordered_indices[prev_idx]
            if not is_text_chunk(chunks[neighbor_idx]):
                break
            txt = get_text(chunks[neighbor_idx])
            if not txt:
                continue
            tks = num_tokens_from_string(txt)
            if tks <= 0:
                continue
            if tks > remaining_prev:
                txt = trim_to_tokens(txt, remaining_prev, from_tail=True)
                tks = num_tokens_from_string(txt)
            prev_ctx.append(txt)
            remaining_prev -= tks
        prev_ctx.reverse()
        next_ctx = []
        remaining_next = token_budget
        for next_idx in range(sorted_pos + 1, total):
            if remaining_next <= 0:
                break
            neighbor_idx = ordered_indices[next_idx]
            if not is_text_chunk(chunks[neighbor_idx]):
                break
            txt = get_text(chunks[neighbor_idx])
            if not txt:
                continue
            tks = num_tokens_from_string(txt)
            if tks <= 0:
                continue
            if tks > remaining_next:
                txt = trim_to_tokens(txt, remaining_next, from_tail=False)
                tks = num_tokens_from_string(txt)
            next_ctx.append(txt)
            remaining_next -= tks
        media_bounds = get_bounds_by_page(ck)
        best_idx = None
        best_dist = None
        candidate_count = 0
        if media_bounds and text_bounds:
            for text_idx, bounds in text_bounds:
                for pn, (t_top, t_bottom) in bounds.items():
                    if pn not in media_bounds:
                        continue
                    m_top, m_bottom = media_bounds[pn]
                    if m_bottom < t_top or m_top > t_bottom:
                        continue
                    candidate_count += 1
                    m_mid = (m_top + m_bottom) / 2.0
                    t_mid = (t_top + t_bottom) / 2.0
                    dist = abs(m_mid - t_mid)
                    if best_dist is None or dist < best_dist:
                        best_dist = dist
                        best_idx = text_idx
        if best_idx is None and media_bounds:
            media_page = min(media_bounds.keys())
            page_order = []
            for ordered_idx in ordered_indices:
                pn, _, _ = extract_position(chunks[ordered_idx])
                if pn == media_page:
                    page_order.append(ordered_idx)
            if page_order and idx in page_order:
                pos_in_page = page_order.index(idx)
                if pos_in_page == 0:
                    for neighbor in page_order[pos_in_page + 1:]:
                        if is_text_chunk(chunks[neighbor]):
                            best_idx = neighbor
                            break
                elif pos_in_page == len(page_order) - 1:
                    for neighbor in reversed(page_order[:pos_in_page]):
                        if is_text_chunk(chunks[neighbor]):
                            best_idx = neighbor
                            break
        if best_idx is not None:
            base_text = get_text(chunks[best_idx])
            sentences = split_sentences(base_text)
            if sentences:
                boundary_idx = find_mid_sentence_index(sentences)
                prev_ctx, next_ctx = collect_context_from_sentences(sentences, boundary_idx, token_budget)
        if not prev_ctx and not next_ctx:
            continue
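
Finally, the fallback branch above only fires for media at the very start or very end of a page's reading order; media stuck mid-page without an overlapping text chunk get no context. A condensed restatement of that rule (names are mine, not the diff's; `page_order` stands for the already-sorted chunk indices on the media's page):

```python
def fallback_neighbor(page_order, media_idx, is_text):
    """Page-head media borrow from the next text chunk, page-tail media from the previous one."""
    pos = page_order.index(media_idx)
    if pos == 0:                                   # media leads the page
        candidates = page_order[pos + 1:]
    elif pos == len(page_order) - 1:               # media closes the page
        candidates = list(reversed(page_order[:pos]))
    else:
        return None                                # mid-page media without overlap: no fallback
    return next((i for i in candidates if is_text(i)), None)

# Page order: an image (index 7) first, followed by two text chunks.
print(fallback_neighbor([7, 2, 5], 7, is_text=lambda i: i in {2, 5}))  # -> 2
```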