Feat: add vision LLM PDF parser (#6173)

### What problem does this PR solve?

Add a vision-LLM-based PDF parser (`VisionParser`). It renders each PDF page to an image with pdfplumber and has a vision-capable LLM describe the page (via `vision_llm_chunk` and `vision_llm_describe_prompt`), as an alternative to the OCR/layout pipeline in `RAGFlowPdfParser`.
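A minimal usage sketch. The import path and the `vision_model` setup are assumptions (only `VisionParser` and its call signature come from this PR); the callback mirrors the `(prog, msg)` shape used in `__call__`:

```python
# Hypothetical call site -- only VisionParser and its call signature come from this PR;
# the import path and the vision_model object are assumptions.
from deepdoc.parser.pdf_parser import VisionParser

def progress(prog=None, msg=""):
    # Progress callback matching the (prog, msg) shape expected by __call__.
    print(f"[{prog}] {msg}")

vision_model = ...  # any vision-capable LLM client accepted by rag.app.picture.vision_llm_chunk

parser = VisionParser(vision_model=vision_model)
# Returns ([(page_description, ""), ...], []) -- one description per parsed page.
sections, tables = parser("manual.pdf", from_page=0, to_page=8, callback=progress)
for text, _ in sections:
    print(text[:80])
```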

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Yongteng Lei authored on 2025-03-18 14:52:20 +08:00 (committed by GitHub)
parent 897fe85b5c
commit 5cf610af40
7 changed files with 413 additions and 102 deletions


@@ -17,26 +17,27 @@
import logging
import os
import random
from timeit import default_timer as timer
import re
import sys
import threading
import trio
import xgboost as xgb
from copy import deepcopy
from io import BytesIO
import re
import pdfplumber
from PIL import Image
from timeit import default_timer as timer
import numpy as np
import pdfplumber
import trio
import xgboost as xgb
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader as pdf2_read
from api import settings
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer
from copy import deepcopy
from huggingface_hub import snapshot_download
from rag.prompts import vision_llm_describe_prompt
from rag.settings import PARALLEL_DEVICES
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
@@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:
class RAGFlowPdfParser:
def __init__(self):
def __init__(self, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -57,12 +58,12 @@ class RAGFlowPdfParser:
^_-
"""
self.ocr = OCR()
self.parallel_limiter = None
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
if hasattr(self, "model_speciess"):
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
else:
@@ -106,7 +107,7 @@ class RAGFlowPdfParser:
def _y_dis(
self, a, b):
return (
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
def _match_proj(self, b):
proj_patt = [
@@ -129,9 +130,9 @@ class RAGFlowPdfParser:
tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split()
tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split()
tks_all = up["text"][-LEN:].strip() \
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
tks_all = rag_tokenizer.tokenize(tks_all).split()
fea = [
up.get("R", -1) == down.get("R", -1),
@@ -153,7 +154,7 @@ class RAGFlowPdfParser:
True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[\(][^\)]+$", up["text"])
and re.search(r"[\)]", down["text"]) else False,
and re.search(r"[\)]", down["text"]) else False,
self._match_proj(down),
True if re.match(r"[A-Z]", down["text"]) else False,
True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -215,7 +216,7 @@ class RAGFlowPdfParser:
continue
for tb in tbls: # for table
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
left *= ZM
top *= ZM
right *= ZM
@@ -309,7 +310,7 @@ class RAGFlowPdfParser:
"page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
self.mean_height[-1] / 3
)
# merge chars in the same rect
for c in Recognizer.sort_Y_firstly(
chars, self.mean_height[pagenum - 1] // 4):
@@ -457,7 +458,7 @@ class RAGFlowPdfParser:
b_["text"],
any(feats),
any(concatting_feats),
))
))
i += 1
continue
# merge up and down
@@ -665,7 +666,7 @@ class RAGFlowPdfParser:
i += 1
continue
lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"])
"-" + str(self.boxes[i]["layoutno"])
if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
"title",
"figure caption",
@@ -968,7 +969,7 @@ class RAGFlowPdfParser:
fnm) if not binary else pdfplumber.open(BytesIO(binary))
total_page = len(pdf.pages)
pdf.close()
return total_page
return total_page
except Exception:
logging.exception("total_page_number")
@@ -994,7 +995,7 @@ class RAGFlowPdfParser:
except Exception as e:
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
self.page_chars = [[] for _ in range(page_to - page_from)]  # If extraction failed, fall back to empty per-page lists.
self.total_page = len(self.pdf.pages)
except Exception:
logging.exception("RAGFlowPdfParser __images__")
@@ -1023,7 +1024,7 @@ class RAGFlowPdfParser:
logging.debug("Images converted.")
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
range(len(self.page_chars))]
range(len(self.page_chars))]
if sum([1 if e else 0 for e in self.is_english]) > len(
self.page_images) / 2:
self.is_english = True
@@ -1036,7 +1037,7 @@ class RAGFlowPdfParser:
if chars[j]["text"] and chars[j + 1]["text"] \
and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
chars[j]["width"]) / 2:
chars[j]["width"]) / 2:
chars[j]["text"] += " "
j += 1
@@ -1045,7 +1046,7 @@ class RAGFlowPdfParser:
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
else:
self.__ocr(i + 1, img, chars, zoomin, id)
if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
@@ -1060,14 +1061,14 @@ class RAGFlowPdfParser:
)
self.page_cum_height.append(img.size[1] / zoomin)
return chars
if self.parallel_limiter:
async with trio.open_nursery() as nursery:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
self.parallel_limiter[i % PARALLEL_DEVICES])
self.parallel_limiter[i % PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):
@@ -1075,9 +1076,9 @@ class RAGFlowPdfParser:
await __img_ocr(i, 0, img, chars, None)
start = timer()
trio.run(__img_ocr_launcher)
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
if not self.is_english and not any(
@@ -1142,7 +1143,7 @@ class RAGFlowPdfParser:
self.page_images[pns[0]].crop((left * ZM, top * ZM,
right *
ZM, min(
bottom, self.page_images[pns[0]].size[1])
bottom, self.page_images[pns[0]].size[1])
))
)
if 0 < ii < len(poss) - 1:
@@ -1240,5 +1241,52 @@ class PlainParser:
raise NotImplementedError
class VisionParser(RAGFlowPdfParser):
def __init__(self, vision_model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vision_model = vision_model
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
try:
with sys.modules[LOCK_KEY_pdfplumber]:
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])]
self.total_page = len(self.pdf.pages)
except Exception:
self.page_images = None
self.total_page = 0
logging.exception("VisionParser __images__")
def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None)
self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
total_pdf_pages = self.total_page
start_page = max(0, from_page)
end_page = min(to_page, total_pdf_pages)
all_docs = []
for idx, img_binary in enumerate(self.page_images or []):
pdf_page_num = idx # 0-based
if pdf_page_num < start_page or pdf_page_num >= end_page:
continue
docs = picture_vision_llm_chunk(
binary=img_binary,
vision_model=self.vision_model,
prompt=vision_llm_describe_prompt(page=pdf_page_num+1),
callback=callback,
)
if docs:
all_docs.append(docs)
return [(doc, "") for doc in all_docs], []
if __name__ == "__main__":
pass
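Separately from `VisionParser`, this diff wires a `parallel_limiter` (one `trio.CapacityLimiter(1)` per OCR device) into `RAGFlowPdfParser.__images__`. Below is a self-contained sketch of that concurrency pattern only: a dummy workload stands in for the real `__ocr` call, and `PARALLEL_DEVICES` is hard-coded here as an assumption (the real value comes from `rag.settings`).

```python
import trio

PARALLEL_DEVICES = 2  # assumption; the real value comes from rag.settings.PARALLEL_DEVICES

async def fake_ocr(page_no: int, device_id: int, limiter: trio.CapacityLimiter):
    # CapacityLimiter(1) serializes pages assigned to one device while other
    # devices keep running concurrently.
    async with limiter:
        await trio.sleep(0.2)  # stand-in for trio.to_thread.run_sync(self.__ocr, ...)
        print(f"page {page_no} finished on device {device_id}")

async def launcher(num_pages: int):
    limiters = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
    async with trio.open_nursery() as nursery:
        for i in range(num_pages):
            device = i % PARALLEL_DEVICES  # round-robin pages across devices
            nursery.start_soon(fake_ocr, i, device, limiters[device])

if __name__ == "__main__":
    trio.run(launcher, 6)
```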