Feat: add context for figure and table (#11547)

### What problem does this PR solve?

Add context for figure table.



![demo_figure_table_context](https://github.com/user-attachments/assets/61b37fac-e22e-40a4-9665-9396c7b4103e)


`==================()` for demonstrating purpose. 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-11-27 10:21:44 +08:00
committed by GitHub
parent 7c3c185038
commit 9d8b96c1d0
11 changed files with 373 additions and 74 deletions

View File

@ -318,6 +318,7 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
@ -330,6 +331,7 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
r = de.join(rows[i:i + batch_size])
tokenize(d, r, eng)
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
@ -338,6 +340,194 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
return res
def attach_media_context(chunks, table_context_size=0, image_context_size=0):
"""
Attach surrounding text chunk content to media chunks (table/image).
Best-effort ordering: if positional info exists on any chunk, use it to
order chunks before collecting context; otherwise keep original order.
"""
if not chunks or (table_context_size <= 0 and image_context_size <= 0):
return chunks
def is_image_chunk(ck):
if ck.get("doc_type_kwd") == "image":
return True
text_val = ck.get("content_with_weight") if isinstance(ck.get("content_with_weight"), str) else ck.get("text")
has_text = isinstance(text_val, str) and text_val.strip()
return bool(ck.get("image")) and not has_text
def is_table_chunk(ck):
return ck.get("doc_type_kwd") == "table"
def is_text_chunk(ck):
return not is_image_chunk(ck) and not is_table_chunk(ck)
def get_text(ck):
if isinstance(ck.get("content_with_weight"), str):
return ck["content_with_weight"]
if isinstance(ck.get("text"), str):
return ck["text"]
return ""
def split_sentences(text):
pattern = r"([.。!?!?;:\n])"
parts = re.split(pattern, text)
sentences = []
buf = ""
for p in parts:
if not p:
continue
if re.fullmatch(pattern, p):
buf += p
sentences.append(buf)
buf = ""
else:
buf += p
if buf:
sentences.append(buf)
return sentences
def trim_to_tokens(text, token_budget, from_tail=False):
if token_budget <= 0 or not text:
return ""
sentences = split_sentences(text)
if not sentences:
return ""
collected = []
remaining = token_budget
seq = reversed(sentences) if from_tail else sentences
for s in seq:
tks = num_tokens_from_string(s)
if tks <= 0:
continue
if tks > remaining:
collected.append(s)
break
collected.append(s)
remaining -= tks
if from_tail:
collected = list(reversed(collected))
return "".join(collected)
def extract_position(ck):
pn = None
top = None
left = None
try:
if ck.get("page_num_int"):
pn = ck["page_num_int"][0]
elif ck.get("page_number") is not None:
pn = ck.get("page_number")
if ck.get("top_int"):
top = ck["top_int"][0]
elif ck.get("top") is not None:
top = ck.get("top")
if ck.get("position_int"):
left = ck["position_int"][0][1]
elif ck.get("x0") is not None:
left = ck.get("x0")
except Exception:
pn = top = left = None
return pn, top, left
indexed = list(enumerate(chunks))
positioned_indices = []
unpositioned_indices = []
for idx, ck in indexed:
pn, top, left = extract_position(ck)
if pn is not None and top is not None:
positioned_indices.append((idx, pn, top, left if left is not None else 0))
else:
unpositioned_indices.append(idx)
if positioned_indices:
positioned_indices.sort(key=lambda x: (int(x[1]), int(x[2]), int(x[3]), x[0]))
ordered_indices = [i for i, _, _, _ in positioned_indices] + unpositioned_indices
else:
ordered_indices = [idx for idx, _ in indexed]
total = len(ordered_indices)
for sorted_pos, idx in enumerate(ordered_indices):
ck = chunks[idx]
token_budget = image_context_size if is_image_chunk(ck) else table_context_size if is_table_chunk(ck) else 0
if token_budget <= 0:
continue
prev_ctx = []
remaining_prev = token_budget
for prev_idx in range(sorted_pos - 1, -1, -1):
if remaining_prev <= 0:
break
neighbor_idx = ordered_indices[prev_idx]
if not is_text_chunk(chunks[neighbor_idx]):
break
txt = get_text(chunks[neighbor_idx])
if not txt:
continue
tks = num_tokens_from_string(txt)
if tks <= 0:
continue
if tks > remaining_prev:
txt = trim_to_tokens(txt, remaining_prev, from_tail=True)
tks = num_tokens_from_string(txt)
prev_ctx.append(txt)
remaining_prev -= tks
prev_ctx.reverse()
next_ctx = []
remaining_next = token_budget
for next_idx in range(sorted_pos + 1, total):
if remaining_next <= 0:
break
neighbor_idx = ordered_indices[next_idx]
if not is_text_chunk(chunks[neighbor_idx]):
break
txt = get_text(chunks[neighbor_idx])
if not txt:
continue
tks = num_tokens_from_string(txt)
if tks <= 0:
continue
if tks > remaining_next:
txt = trim_to_tokens(txt, remaining_next, from_tail=False)
tks = num_tokens_from_string(txt)
next_ctx.append(txt)
remaining_next -= tks
if not prev_ctx and not next_ctx:
continue
self_text = get_text(ck)
pieces = [*prev_ctx]
if self_text:
pieces.append(self_text)
pieces.extend(next_ctx)
combined = "\n".join(pieces)
original = ck.get("content_with_weight")
if "content_with_weight" in ck:
ck["content_with_weight"] = combined
elif "text" in ck:
original = ck.get("text")
ck["text"] = combined
if combined != original:
if "content_ltks" in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(combined)
if "content_sm_ltks" in ck:
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
if positioned_indices:
chunks[:] = [chunks[i] for i in ordered_indices]
return chunks
def add_positions(d, poss):
if not poss:
return