Initial commit
34
DeepSeek-OCR-master/DeepSeek-OCR-hf/run_dpsk_ocr.py
Normal file
@@ -0,0 +1,34 @@
from transformers import AutoModel, AutoTokenizer
import torch
import os

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

model_name = 'deepseek-ai/DeepSeek-OCR'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda().to(torch.bfloat16)

# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

# infer(self, tokenizer, prompt='', image_file='', output_path=' ', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False)
# Tiny:   base_size = 512,  image_size = 512,  crop_mode = False
# Small:  base_size = 640,  image_size = 640,  crop_mode = False
# Base:   base_size = 1024, image_size = 1024, crop_mode = False
# Large:  base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640,  crop_mode = True
res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path=output_path, base_size=1024, image_size=640, crop_mode=True, save_results=True, test_compress=True)
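Editor's note: switching resolution modes only changes the size arguments of the same `infer` call. A minimal sketch, assuming the `infer` signature documented in the comment above (the paths are placeholders, not files from the repository):

# Hypothetical "Tiny"-mode run on the same model object (illustrative only).
res_tiny = model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",     # the layout-free prompt from the commented alternative above
    image_file='your_image.jpg',      # placeholder input path
    output_path='your/output/dir',    # placeholder output directory
    base_size=512, image_size=512, crop_mode=False,  # Tiny mode per the mode table
    save_results=True,
)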
42
DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py
Normal file
@@ -0,0 +1,42 @@
# TODO: change modes
# Tiny:   base_size = 512,  image_size = 512,  crop_mode = False
# Small:  base_size = 640,  image_size = 640,  crop_mode = False
# Base:   base_size = 1024, image_size = 1024, crop_mode = False
# Large:  base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640,  crop_mode = True

BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS = 2
MAX_CROPS = 6  # max: 9; if your GPU memory is small, 6 is recommended
MAX_CONCURRENCY = 100  # if you have limited GPU memory, lower the concurrency count
NUM_WORKERS = 64  # image pre-processing (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = 'deepseek-ai/DeepSeek-OCR'  # change to your model path

# TODO: change INPUT_PATH
# .pdf: run_dpsk_ocr_pdf.py
# .jpg, .png, .jpeg: run_dpsk_ocr_image.py
# OmniDocBench images path: run_dpsk_ocr_eval_batch.py

INPUT_PATH = ''
OUTPUT_PATH = ''

PROMPT = '<image>\n<|grounding|>Convert the document to markdown.'
# PROMPT = '<image>\nFree OCR.'

# TODO: commonly used prompts
# document:            <image>\n<|grounding|>Convert the document to markdown.
# other images:        <image>\n<|grounding|>OCR this image.
# without layouts:     <image>\nFree OCR.
# figures in document: <image>\nParse the figure.
# general:             <image>\nDescribe this image in detail.
# rec:                 <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
#                      (e.g. the reference text '先天下之忧而忧')
# ...

from transformers import AutoTokenizer

TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
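Editor's note: if GPU memory is tight, the knobs above can be turned down together. An illustrative low-memory variant (example values only, not settings shipped with the repository):

# Illustrative low-memory configuration sketch:
# BASE_SIZE = 640
# IMAGE_SIZE = 640
# CROP_MODE = False      # "Small" mode from the table above
# MAX_CROPS = 6          # the cap already suggested for small GPUs
# MAX_CONCURRENCY = 16   # lower concurrency to reduce peak memory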
@@ -0,0 +1,174 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy


class MlpProjector(nn.Module):

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
            modules = nn.Sequential(*modules)
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
        self.layers = modules

    def forward(self, x):
        if self.cfg.get("token_pooling", False):
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # concatenate along the channel dimension
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # pass through the linear layer
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        if self.cfg.get("conv_fusion_high_low_features", False):
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
            high_x = x[..., :self.cfg.input_dim[0]]
            low_x = x[..., self.cfg.input_dim[0]:]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'low_high_split_mlp_gelu':
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
            return x

        if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            """compute padding"""
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            """4 to 1 concat"""
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type:
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            fwd = 0

        return fwd * 3
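Editor's note: the `downsample_mlp_gelu` path above merges each square neighborhood of vision tokens into one token before the MLP. A standalone shape check (not part of the repository; the grid size and downsample ratio are illustrative) behaves as follows:

import torch
import torch.nn.functional as F

# Illustrative shape check for the "4 to 1 concat" token merge used above.
bs, h, w, c = 2, 64, 64, 1024        # e.g. a 64x64 grid of 1024-d vision tokens
downsample_ratio = 2                  # 2x2 neighborhoods -> 1 token
x = torch.randn(bs, h * w, c)

x = x.reshape(bs, h, w, c).permute(0, 3, 1, 2)            # B, C, H, W
x = F.unfold(x, kernel_size=downsample_ratio,
             stride=downsample_ratio, padding=0)           # B, C*4, (H*W)/4
x = x.permute(0, 2, 1)                                     # B, (H*W)/4, C*4
print(x.shape)                                             # torch.Size([2, 1024, 4096])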
504
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/clip_sdpa.py
Normal file
@@ -0,0 +1,504 @@
from contextlib import nullcontext
import math
import logging
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
#     FlashSelfAttention,
#     NoopTransformerLayer,
#     _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func

# from megatron.model.module import MegatronModule

# try:
#     from einops import rearrange
# except ImportError:
#     rearrange = None

# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func

# try:
#     # flash attention 2.x
#     from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
#     try:
#         # flash attention 1.x
#         from flash_attn.flash_attn_interface import flash_attn_unpadded_func
#     except ImportError:
#         flash_attn_unpadded_func = None

# try:
#     from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
#     flash_attn_unpadded_relative_attention_bias_func = None

# try:
#     from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
#     mask_flash_attn_unpadded_func = None

logger = logging.getLogger(__name__)


class LayerNormfp32(torch.nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def get_abs_pos(abs_pos, tgt_size):
    # abs_pos: L, C
    # tgt_size: M
    # return: M, C
    dim = abs_pos.size(-1)
    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1, 2).contiguous()
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode='bicubic',
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
        vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
        vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
        return vision_pos_embed
    else:
        return abs_pos


@torch.jit.script
def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
        super().__init__()
        self.embed_dim = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size

        self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values, patch_embeds):
        batch_size = pixel_values.shape[0]
        # patch_embeds = self.patch_embedding(
        #     pixel_values
        # )  # shape = [*, width, grid, grid]

        if patch_embeds is None:
            patch_embeds = self.patch_embedding(pixel_values)
            # shape = [*, width, grid, grid]

        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # x = torch.cat([cls_token, x], dim=1)
        embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
        # embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class NoTPFeedForward(nn.Module):
    def __init__(
        self,
        cfg,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        output = self.fc2(quick_gelu(self.fc1(x)))
        return output


# from optimus.flash_attn_interface import flash_attn_qkvpacked_func


# class NoTPAttention(nn.Module):
#     def __init__(self, cfg):
#         super().__init__()
#         self.num_heads = cfg.num_attention_heads
#         self.n_local_heads = cfg.num_attention_heads
#         self.head_dim = cfg.hidden_size // cfg.num_attention_heads
#         self.max_seq_len = cfg.seq_length
#         self.use_flash_attention = cfg.use_flash_attn
#
#         self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
#         self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
#
#         # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
#
#         self.attn_drop = cfg.attention_dropout
#
#     def forward(
#         self,
#         x: torch.Tensor,
#     ):
#         bsz, seqlen, _ = x.shape
#         xqkv = self.qkv_proj(x)
#         xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
#
#         if self.use_flash_attention:
#             output = flash_attn_qkvpacked_func(xqkv)
#             output = output.view(bsz, seqlen, -1)
#         else:
#             xq, xk, xv = torch.split(xqkv, 1, dim=2)
#             xq = xq.squeeze(2)
#             xk = xk.squeeze(2)
#             xv = xv.squeeze(2)
#             # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
#
#             # (B, num_head, S, head_size)
#             xq = xq.permute(0, 2, 1, 3)
#             xk = xk.permute(0, 2, 1, 3)
#             xv = xv.permute(0, 2, 1, 3)
#
#             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
#             output = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
#         output = self.out_proj(output)
#         return output


# from optimus.flash_attn_interface import flash_attn_qkvpacked_func


class NoTPAttention(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.n_local_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn

        self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
        self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)

        # self.core_attention = CoreAttention(cfg, AttnType.self_attn)

        self.attn_drop = cfg.attention_dropout

    def forward(
        self,
        x: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape
        xqkv = self.qkv_proj(x)
        xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)

        if self.use_flash_attention:
            output = flash_attn_qkvpacked_func(xqkv)
            output = output.view(bsz, seqlen, -1)
            # xq, xk, xv = torch.split(xqkv, 1, dim=2)
            # xq = xq.squeeze(2)
            # xk = xk.squeeze(2)
            # xv = xv.squeeze(2)
            # # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
            #
            # # (B, num_head, S, head_size)
            # xq = xq.permute(0, 2, 1, 3)
            # xk = xk.permute(0, 2, 1, 3)
            # xv = xv.permute(0, 2, 1, 3)
            # # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            # output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
            # output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
            # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
        else:
            # output = flash_attn_qkvpacked_func(xqkv)
            xq, xk, xv = torch.split(xqkv, 1, dim=2)
            xq = xq.squeeze(2)
            xk = xk.squeeze(2)
            xv = xv.squeeze(2)
            # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]

            # (B, num_head, S, head_size)
            xq = xq.permute(0, 2, 1, 3)
            xk = xk.permute(0, 2, 1, 3)
            xv = xv.permute(0, 2, 1, 3)
            # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
            output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
        output = self.out_proj(output)
        return output


class NoTPTransformerBlock(nn.Module):
    def __init__(self, cfg, layer_id: int, multiple_of=256):
        super().__init__()

        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.self_attn = NoTPAttention(cfg)
        self.mlp = NoTPFeedForward(
            cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
        )
        self.layer_id = layer_id
        self.layer_norm1 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )
        self.layer_norm2 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )

    def forward(self, x: torch.Tensor):
        residual = self.self_attn.forward(self.layer_norm1(x))
        h = x + residual
        out = h + self.mlp.forward(self.layer_norm2(h))
        return out


class NoTPTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        # self.recompute_list = self.cfg.get("recompute_list", [])
        self.num_layers = cfg.num_layers  # _get_num_layers(cfg)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(self.num_layers):
            self.layers.append(
                NoTPTransformerBlock(
                    cfg,
                    layer_id + 1,
                )
            )

    def forward(
        self,
        hidden_states,
    ):
        for lid, layer in enumerate(self.layers):
            # if lid in self.recompute_list:
            #     def custom(layer_id):
            #         def custom_forward(*args, **kwargs):
            #             x_ = self.layers[layer_id](*args, **kwargs)
            #             return x_
            #         return custom_forward
            #
            #     assert hidden_states.requires_grad == True, logger.warning(
            #         "When using recomputation, the input must have a grad fn"
            #     )
            #     hidden_states = tensor_parallel.checkpoint(
            #         custom(lid),
            #         False,
            #         hidden_states.contiguous()
            #     )
            # else:
            hidden_states = layer(hidden_states)

        return hidden_states


# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter

class VitModel(nn.Module):
    def __init__(
        self,
        cfg,
        freeze_embed=False,
        freeze_pre_norm=False
    ) -> None:
        super().__init__()

        self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)

        if freeze_embed:
            for name, param in self.embeddings.named_parameters():
                param.requires_grad = False

        self.transformer = NoTPTransformer(cfg=cfg)

        if cfg.get("fp32norm", False):
            logger.info("Load fp32 layernorm for ViT.")
            self.pre_layrnorm = LayerNormfp32(
                cfg.hidden_size,
                eps=cfg.get("pre_layernorm_epsilon", 1e-5),
            )
        else:
            self.pre_layrnorm = torch.nn.LayerNorm(
                cfg.hidden_size,
                eps=cfg.get("pre_layernorm_epsilon", 1e-5),
            )

        # self.pre_layrnorm = RMSNorm(
        #     cfg.hidden_size,
        #     eps=cfg.get("pre_layernorm_epsilon", 1e-5),
        #     sequence_parallel=False,
        #     use_fp32=True,
        #     use_optimus=True,
        # )

        if freeze_pre_norm:
            for name, param in self.pre_layrnorm.named_parameters():
                param.requires_grad = False

        for p in self.parameters():
            p.micro_dp = True

    def set_input_tensor(self, input_tensor):
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.transformer.set_input_tensor(input_tensor[0])

    def __str__(self) -> str:
        return "open_clip"

    def forward(
        self,
        x,
        patch_embeds
    ):
        x = self.embeddings(x, patch_embeds)
        hidden_states = self.pre_layrnorm(x)

        # hidden_states, dis = local_dp_scatter(hidden_states)
        output = self.transformer(hidden_states)

        # output = local_dp_reduce(output, dis)

        return output


vit_model_cfg = adict(
    num_layers=24,
    hidden_size=1024,
    num_heads=16,
    num_attention_heads=16,
    ffn_hidden_size=4096,
    seq_length=256,
    max_position_embeddings=256,
    use_flash_attn=False,
    understand_projector_stride=2,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    no_persist_layer_norm=False,
    layernorm_epsilon=1e-5,
    pre_layernorm_epsilon=1e-5,
    image_size=224,
    patch_size=14,
    recompute_list=[]
)


def build_clip_l():
    return VitModel(
        cfg=vit_model_cfg,
        freeze_embed=False,
        freeze_pre_norm=False,
    )


if __name__ == '__main__':

    from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b

    vit_model_cfg = adict(
        num_layers=24,
        hidden_size=1024,
        num_attention_heads=16,
        ffn_hidden_size=4096,
        seq_length=256,
        max_position_embeddings=256,
        use_flash_attn=False,
        understand_projector_stride=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        no_persist_layer_norm=False,
        layernorm_epsilon=1e-5,
        pre_layernorm_epsilon=1e-5,
        image_size=224,
        patch_size=14,
        recompute_list=[]
    )

    sam_model = build_sam_vit_b()

    vision_model = VitModel(
        cfg=vit_model_cfg,
        freeze_embed=False,
        freeze_pre_norm=False,
    )

    # model = VitModel(1344)
    # x = torch.zeros(2, 3, 224, 224)
    x = torch.zeros(2, 3, 1024, 1024)

    with torch.no_grad():
        # y = vision_model(x)
        patch_embed = sam_model(x)
        print(patch_embed.shape)
        y = vision_model(x, patch_embed)
        print(y.shape)

    image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))

    print(image_feature.shape)
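Editor's note: the `__main__` demo above wires the SAM encoder into this CLIP-style ViT. At a 1024x1024 input the SAM branch emits a [B, 1024, 16, 16] feature map, which is flattened to 256 tokens and passed in as `patch_embeds`. A small bookkeeping sketch of those shapes (plain arithmetic, assuming the default `vit_model_cfg` values shown above; nothing from the repository is imported):

# Shape bookkeeping for the demo above.
image_size, sam_patch = 1024, 16
grid = image_size // sam_patch      # 64: patch grid coming out of the SAM patch embed
grid = grid // 2 // 2               # net_2 and net_3 each halve it -> 16
num_tokens = grid * grid            # 256, matching seq_length=256 in vit_model_cfg
print(grid, num_tokens)             # 16 256
# The CLIP branch prepends a CLS token (257 tokens total); y[:, 1:] drops it so the
# [B, 256, 1024] transformer output can be added to the flattened SAM features.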
@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type
from functools import partial
from flash_attn import flash_attn_qkvpacked_func
# from .common import LayerNorm2d, MLPBlock

# from mmgpt.model.vision_encoder.flash_4 import _attention_rel_h_rel_w


def get_abs_pos(abs_pos, tgt_size):

    dtype = abs_pos.dtype

    src_size = abs_pos.size(1)

    if src_size != tgt_size:
        old_pos_embed = abs_pos.permute(0, 3, 1, 2)
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode='bicubic',
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        return new_pos_embed
    else:
        return abs_pos


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # x = x + self.pos_embed
            x = x + get_abs_pos(self.pos_embed, x.size(1))

        for blk in self.blocks:
            x = blk(x)

        neck_output = self.neck(x.permute(0, 3, 1, 2))
        conv2_output = self.net_2(neck_output)
        conv3_output = self.net_3(conv2_output)

        return conv3_output


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        rel_h, rel_w = None, None
        if self.use_rel_pos:
            rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        q = q.view(B, self.num_heads, H * W, -1)
        k = k.view(B, self.num_heads, H * W, -1)
        v = v.view(B, self.num_heads, H * W, -1)

        if self.use_rel_pos:
            rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
            # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
        else:
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            # qkv = torch.stack([q, k, v], dim=1).transpose(1, 3).reshape(B, H * W, 3, self.num_heads, -1)
            # x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False).transpose(1, 2)

        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        dtype = rel_pos.dtype
        rel_pos = rel_pos.to(torch.float32)
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        ).to(dtype)
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        rel_h, rel_w (Tensor): decomposed relative position terms for the height and width axes.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    rel_h = rel_h.unsqueeze(-1)
    rel_w = rel_w.unsqueeze(-2)
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )

    if checkpoint is not None:
        # with open(checkpoint, "rb") as f:
        state_dict = torch.load(checkpoint)
        # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
        # ocr-anything
        # image_encoder.load_state_dict(state_dict, strict=True)
        # tob
        image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
        print(checkpoint)
    return image_encoder
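Editor's note: `window_partition` and `window_unpartition` above are exact inverses once the padding is stripped. A round-trip sketch with illustrative sizes, assuming both helpers are importable from this module:

import torch

# Round-trip check for the windowing helpers (illustrative sizes).
x = torch.randn(1, 65, 65, 8)                          # 65 is not divisible by window_size=14
windows, pad_hw = window_partition(x, 14)
print(windows.shape, pad_hw)                           # torch.Size([25, 14, 14, 8]) (70, 70)
y = window_unpartition(windows, 14, pad_hw, (65, 65))
assert torch.equal(x, y)                               # the padded border is removed exactly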
582
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepseek_ocr.py
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
|
||||||
|
"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
|
||||||
|
import math
|
||||||
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
|
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from transformers import BatchFeature
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor import SamplingMetadata
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||||
|
MultiModalKwargs, NestedTensors)
|
||||||
|
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||||
|
ImageSize, MultiModalDataItems)
|
||||||
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
BaseProcessingInfo, PromptReplacement,
|
||||||
|
PromptUpdate)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
|
||||||
|
MlpProjectorConfig,
|
||||||
|
VisionEncoderConfig)
|
||||||
|
from process.image_process import (
|
||||||
|
DeepseekOCRProcessor, count_tiles)
|
||||||
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
# from vllm.utils import is_list_of
|
||||||
|
|
||||||
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
|
init_vllm_registered_model, maybe_prefix,
|
||||||
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
|
from deepencoder.sam_vary_sdpa import build_sam_vit_b
|
||||||
|
from deepencoder.clip_sdpa import build_clip_l
|
||||||
|
from deepencoder.build_linear import MlpProjector
|
||||||
|
from addict import Dict
|
||||||
|
# import time
|
||||||
|
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, PRINT_NUM_VIS_TOKENS, PROMPT
|
||||||
|
# The image token id may be various
|
||||||
|
_IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
|
||||||
|
|
||||||
|
def get_hf_config(self):
|
||||||
|
return self.ctx.get_hf_config(DeepseekVLV2Config)
|
||||||
|
|
||||||
|
def get_hf_processor(self, **kwargs: object):
|
||||||
|
return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
|
||||||
|
|
||||||
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
|
return {"image": None}
|
||||||
|
|
||||||
|
def get_num_image_tokens(self,
|
||||||
|
*,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
cropping: bool = True) -> int:
|
||||||
|
hf_processor = self.get_hf_processor()
|
||||||
|
|
||||||
|
|
||||||
|
# image_size = hf_processor.image_size
|
||||||
|
# patch_size = hf_processor.patch_size
|
||||||
|
# downsample_ratio = hf_processor.downsample_ratio
|
||||||
|
|
||||||
|
image_size = IMAGE_SIZE
|
||||||
|
base_size = BASE_SIZE
|
||||||
|
patch_size = 16
|
||||||
|
downsample_ratio = 4
|
||||||
|
|
||||||
|
if CROP_MODE:
|
||||||
|
if image_width <= 640 and image_height <= 640:
|
||||||
|
crop_ratio = [1, 1]
|
||||||
|
else:
|
||||||
|
# images_crop_raw, crop_ratio = hf_processor.dynamic_preprocess(image)
|
||||||
|
|
||||||
|
# find the closest aspect ratio to the target
|
||||||
|
crop_ratio = count_tiles(image_width, image_height, image_size=IMAGE_SIZE)
|
||||||
|
|
||||||
|
# print('===========')
|
||||||
|
# print('crop_ratio ', crop_ratio)
|
||||||
|
# print('============')
|
||||||
|
|
||||||
|
num_width_tiles, num_height_tiles = crop_ratio
|
||||||
|
else:
|
||||||
|
num_width_tiles = num_height_tiles = 1
|
||||||
|
|
||||||
|
h = w = math.ceil((base_size // patch_size) / downsample_ratio)
|
||||||
|
|
||||||
|
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
|
||||||
|
|
||||||
|
global_views_tokens = h * (w + 1)
|
||||||
|
if num_width_tiles >1 or num_height_tiles>1:
|
||||||
|
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
|
||||||
|
else:
|
||||||
|
local_views_tokens = 0
|
||||||
|
|
||||||
|
|
||||||
|
return global_views_tokens + local_views_tokens + 1
|
||||||
|
|
||||||
|
def get_image_size_with_most_features(self) -> ImageSize:
|
||||||
|
|
||||||
|
if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
|
||||||
|
return ImageSize(width=1024*2, height=1024*2)
|
||||||
|
return ImageSize(width=640*2, height=640*2)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekOCRDummyInputsBuilder(
        BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        if '<image>' in PROMPT:
            return {
                "image":
                DeepseekOCRProcessor().tokenize_with_images(
                    images=self._get_dummy_images(width=max_image_size.width,
                                                  height=max_image_size.height,
                                                  num_images=num_images),
                    bos=True, eos=True, cropping=CROP_MODE)
            }
        else:
            return {
                "image": []
            }

class DeepseekOCRMultiModalProcessor(
        BaseMultiModalProcessor[DeepseekOCRProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # print(mm_data)
        if mm_data:
            processed_outputs = self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(prompt=prompt, **mm_data),
                mm_kwargs,
            )
        else:
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=True,
                                          return_tensors="pt")

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            # image_embeds=MultiModalFieldConfig.batched("image2"),
            images_crop=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                width = images[0][-1][0][0]
                height = images[0][-1][0][1]

                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=width,
                    image_height=height,
                    # flag = True,
                    cropping=CROP_MODE,
                )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        # The processor logic is different for len(images) <= 2 vs > 2
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 2:
            # This code path corresponds to the cache being disabled
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

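# Descriptive note (added): the decorator below wires the three pieces into
# vLLM's multimodal registry. DeepseekOCRProcessingInfo reports image-token
# counts and limits, DeepseekOCRDummyInputsBuilder supplies dummy inputs for
# profiling, and DeepseekOCRMultiModalProcessor maps the HF-processor outputs
# onto the prompt before the model's forward pass.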
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
DeepseekOCRMultiModalProcessor,
|
||||||
|
info=DeepseekOCRProcessingInfo,
|
||||||
|
dummy_inputs=DeepseekOCRDummyInputsBuilder)
|
||||||
|
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
|
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
|
||||||
|
"language.": "language_model.",
|
||||||
|
})
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
multimodal_config = vllm_config.model_config.multimodal_config
|
||||||
|
|
||||||
|
# config.model_type ='deepseek_vl_v2'
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
|
||||||
|
self.vision_config = config.vision_config
|
||||||
|
self.projector_config = config.projector_config
|
||||||
|
self.text_config = config.text_config
|
||||||
|
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
tokenizer = cached_tokenizer_from_config(model_config)
|
||||||
|
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
|
||||||
|
|
||||||
|
self.sam_model = build_sam_vit_b()
|
||||||
|
self.vision_model = build_clip_l()
|
||||||
|
|
||||||
|
n_embed = 1280
|
||||||
|
self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
|
||||||
|
self.tile_tag = config.tile_tag
|
||||||
|
self.global_view_pos = config.global_view_pos
|
||||||
|
|
||||||
|
# self.sam_model = torch.compile(self.sam_model, mode="reduce-overhead")
|
||||||
|
# self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
|
||||||
|
# self.projector = torch.compile(self.projector, mode="max-autotune")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# special token for image token sequence format
|
||||||
|
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
|
||||||
|
if self.tile_tag == "2D":
|
||||||
|
# <|view_separator|>, <|\n|>
|
||||||
|
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
|
||||||
|
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.text_config.topk_method == "noaux_tc":
|
||||||
|
architectures = ["DeepseekV3ForCausalLM"]
|
||||||
|
elif not self.text_config.use_mla:
|
||||||
|
architectures = ["DeepseekForCausalLM"]
|
||||||
|
else:
|
||||||
|
architectures = ["DeepseekV2ForCausalLM"]
|
||||||
|
|
||||||
|
self.language_model = init_vllm_registered_model(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
hf_config=self.text_config,
|
||||||
|
prefix=maybe_prefix(prefix, "language"),
|
||||||
|
architectures=architectures,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(
|
||||||
|
self, **kwargs: object):
|
||||||
|
|
||||||
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
|
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
|
||||||
|
images_crop = kwargs.pop("images_crop", None)
|
||||||
|
|
||||||
|
|
||||||
|
if pixel_values is None or torch.sum(pixel_values).item() == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of image sizes. "
|
||||||
|
f"Got type: {type(images_spatial_crop)}")
|
||||||
|
|
||||||
|
if not isinstance(images_crop, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of image crop. "
|
||||||
|
f"Got type: {type(images_crop)}")
|
||||||
|
|
||||||
|
return [pixel_values, images_crop, images_spatial_crop]
|
||||||
|
|
||||||
|
|
||||||
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _pixel_values_to_embedding(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
images_crop: torch.Tensor,
|
||||||
|
images_spatial_crop: torch.Tensor,
|
||||||
|
) -> NestedTensors:
|
||||||
|
|
||||||
|
# Pixel_values (global view): [n_image, batch_size, 3, height, width]
|
||||||
|
# images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
|
||||||
|
# images_crop (local view): [n_image, batch_size, num_patches, 3, h, w]
|
||||||
|
# split the pixel and image_crop, all batch_size = 1
|
||||||
|
|
||||||
|
images_in_this_batch = []
|
||||||
|
|
||||||
|
|
||||||
|
# print(type(images_crop))
|
||||||
|
|
||||||
|
# print(pixel_values.shape)
|
||||||
|
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for jdx in range(images_spatial_crop.size(0)):
|
||||||
|
# with torch.set_grad_enabled(False):
|
||||||
|
patches = images_crop[jdx][0].to(torch.bfloat16) # batch_size = 1
|
||||||
|
image_ori = pixel_values[jdx]
|
||||||
|
crop_shape = images_spatial_crop[jdx][0]
|
||||||
|
|
||||||
|
if torch.sum(patches).item() != 0: # if all values = 0, no crop
|
||||||
|
# P, C, H, W = patches.shape
|
||||||
|
# crop_flag = 1
|
||||||
|
local_features_1 = self.sam_model(patches)
|
||||||
|
#TODO del patches
|
||||||
|
# torch.compiler.cudagraph_mark_step_begin()
|
||||||
|
local_features_2 = self.vision_model(patches, local_features_1)
|
||||||
|
|
||||||
|
|
||||||
|
local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
||||||
|
local_features = self.projector(local_features)
|
||||||
|
|
||||||
|
|
||||||
|
global_features_1 = self.sam_model(image_ori)
|
||||||
|
global_features_2 = self.vision_model(image_ori, global_features_1)
|
||||||
|
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
||||||
|
global_features = self.projector(global_features)
|
||||||
|
|
||||||
|
if PRINT_NUM_VIS_TOKENS:
|
||||||
|
print('=====================')
|
||||||
|
print('BASE: ', global_features.shape)
|
||||||
|
print('PATCHES: ', local_features.shape)
|
||||||
|
print('=====================')
|
||||||
|
|
||||||
|
_, hw, n_dim = global_features.shape
|
||||||
|
h = w = int(hw ** 0.5)
|
||||||
|
|
||||||
|
_2, hw2, n_dim2 = local_features.shape
|
||||||
|
h2 = w2 = int(hw2 ** 0.5)
|
||||||
|
|
||||||
|
width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
|
||||||
|
|
||||||
|
global_features = global_features.view(h, w, n_dim)
|
||||||
|
|
||||||
|
global_features = torch.cat(
|
||||||
|
[global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
global_features = global_features.view(-1, n_dim)
|
||||||
|
|
||||||
|
|
||||||
|
local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
|
||||||
|
local_features = torch.cat(
|
||||||
|
[local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
|
||||||
|
)
|
||||||
|
local_features = local_features.view(-1, n_dim2)
|
||||||
|
|
||||||
|
global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
global_features_1 = self.sam_model(image_ori)
|
||||||
|
global_features_2 = self.vision_model(image_ori, global_features_1)
|
||||||
|
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
||||||
|
global_features = self.projector(global_features)
|
||||||
|
|
||||||
|
if PRINT_NUM_VIS_TOKENS:
|
||||||
|
print('=====================')
|
||||||
|
print('BASE: ', global_features.shape)
|
||||||
|
print('NO PATCHES')
|
||||||
|
print('=====================')
|
||||||
|
|
||||||
|
_, hw, n_dim = global_features.shape
|
||||||
|
h = w = int(hw ** 0.5)
|
||||||
|
|
||||||
|
global_features = global_features.view(h, w, n_dim)
|
||||||
|
|
||||||
|
global_features = torch.cat(
|
||||||
|
[global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
global_features = global_features.view(-1, n_dim)
|
||||||
|
|
||||||
|
global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
|
||||||
|
|
||||||
|
images_in_this_batch.append(global_local_features)
|
||||||
|
|
||||||
|
return images_in_this_batch
|
||||||
|
|
||||||
|
def _process_image_input(
|
||||||
|
self, image_input) -> torch.Tensor:
|
||||||
|
|
||||||
|
|
||||||
|
# image_input: [pixel_values, images_crop, images_spatial_crop]
|
||||||
|
|
||||||
|
pixel_values = image_input[0].to(torch.bfloat16)
|
||||||
|
# print(image_input[1][0].shape)
|
||||||
|
# print(type(image_input[1]))
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
# images_crop = image_input[1].to(torch.bfloat16)
|
||||||
|
images_crop = image_input[1]
|
||||||
|
# images_crop = image_input[1]
|
||||||
|
images_spatial_crop = image_input[2].to(dtype=torch.long)
|
||||||
|
|
||||||
|
# local_start = time.time()
|
||||||
|
vision_features = self._pixel_values_to_embedding(
|
||||||
|
pixel_values=pixel_values, images_crop = images_crop, images_spatial_crop=images_spatial_crop)
|
||||||
|
|
||||||
|
# local_total_time = time.time() - local_start
|
||||||
|
|
||||||
|
# print('encoder_time: ', local_total_time)
|
||||||
|
# exit()
|
||||||
|
return vision_features
|
||||||
|
|
||||||
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
|
return self.language_model
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(
|
||||||
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
if image_input is None:
|
||||||
|
return None
|
||||||
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
|
return vision_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
if multimodal_embeddings is not None:
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids, inputs_embeds, multimodal_embeddings,
|
||||||
|
self.image_token_id)
|
||||||
|
# print(len(multimodal_embeddings))
|
||||||
|
# print(input_ids.shape)
|
||||||
|
# print(type(inputs_embeds))
|
||||||
|
# print(inputs_embeds.shape)
|
||||||
|
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs: object):
|
||||||
|
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||||
|
# condition is for v0 compatibility
|
||||||
|
elif inputs_embeds is None:
|
||||||
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
|
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||||
|
vision_embeddings)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
hidden_states = self.language_model(input_ids,
|
||||||
|
positions,
|
||||||
|
intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
|
||||||
|
processed_weights = []
|
||||||
|
|
||||||
|
for name, tensor in weights:
|
||||||
|
if 'sam_model' in name or 'vision_model' in name or 'projector' in name or 'image_newline' in name or 'view_seperator' in name:
|
||||||
|
new_name = name.replace('model.', '', 1)
|
||||||
|
else:
|
||||||
|
new_name = 'language.' + name
|
||||||
|
|
||||||
|
processed_weights.append((new_name, tensor))
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
autoloaded_weights = loader.load_weights(processed_weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return autoloaded_weights
|
||||||
502
DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py
Normal file
@ -0,0 +1,502 @@
|
|||||||
|
import math
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
|
||||||
|
from transformers.processing_utils import ProcessorMixin
|
||||||
|
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
|
||||||
|
|
||||||
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
    return best_ratio


def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    # print(target_ratios)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    return target_aspect_ratio

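# Illustrative example (not in the original file): with the config defaults
# MIN_CROPS = 2 and MAX_CROPS = 6, a 1000 x 1400 page has aspect ratio ~0.71,
# and the closest admissible grid is (2, 3), i.e. 2 tiles wide by 3 tiles tall:
#
#   count_tiles(1000, 1400, image_size=640)   # -> (2, 3)
#
# dynamic_preprocess() below then resizes that page to 1280 x 1920 and cuts it
# into six 640 x 640 local views.
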
def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
|
||||||
|
orig_width, orig_height = image.size
|
||||||
|
aspect_ratio = orig_width / orig_height
|
||||||
|
|
||||||
|
# calculate the existing image aspect ratio
|
||||||
|
target_ratios = set(
|
||||||
|
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||||
|
i * j <= max_num and i * j >= min_num)
|
||||||
|
# print(target_ratios)
|
||||||
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||||
|
|
||||||
|
# find the closest aspect ratio to the target
|
||||||
|
target_aspect_ratio = find_closest_aspect_ratio(
|
||||||
|
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
||||||
|
|
||||||
|
# print(target_aspect_ratio)
|
||||||
|
# calculate the target width and height
|
||||||
|
target_width = image_size * target_aspect_ratio[0]
|
||||||
|
target_height = image_size * target_aspect_ratio[1]
|
||||||
|
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||||
|
|
||||||
|
# resize the image
|
||||||
|
resized_img = image.resize((target_width, target_height))
|
||||||
|
processed_images = []
|
||||||
|
for i in range(blocks):
|
||||||
|
box = (
|
||||||
|
(i % (target_width // image_size)) * image_size,
|
||||||
|
(i // (target_width // image_size)) * image_size,
|
||||||
|
((i % (target_width // image_size)) + 1) * image_size,
|
||||||
|
((i // (target_width // image_size)) + 1) * image_size
|
||||||
|
)
|
||||||
|
# split the image
|
||||||
|
split_img = resized_img.crop(box)
|
||||||
|
processed_images.append(split_img)
|
||||||
|
assert len(processed_images) == blocks
|
||||||
|
if use_thumbnail and len(processed_images) != 1:
|
||||||
|
thumbnail_img = image.resize((image_size, image_size))
|
||||||
|
processed_images.append(thumbnail_img)
|
||||||
|
return processed_images, target_aspect_ratio
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageTransform:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
normalize: bool = True):
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
self.normalize = normalize
|
||||||
|
|
||||||
|
transform_pipelines = [T.ToTensor()]
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
transform_pipelines.append(T.Normalize(mean, std))
|
||||||
|
|
||||||
|
self.transform = T.Compose(transform_pipelines)
|
||||||
|
|
||||||
|
def __call__(self, pil_img: Image.Image):
|
||||||
|
x = self.transform(pil_img)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekOCRProcessor(ProcessorMixin):
|
||||||
|
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||||
|
attributes = ["tokenizer"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: LlamaTokenizerFast = TOKENIZER,
|
||||||
|
candidate_resolutions: Tuple[Tuple[int, int]] = [[1024, 1024]],
|
||||||
|
patch_size: int = 16,
|
||||||
|
downsample_ratio: int = 4,
|
||||||
|
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
normalize: bool = True,
|
||||||
|
image_token: str = "<image>",
|
||||||
|
pad_token: str = "<|▁pad▁|>",
|
||||||
|
add_special_token: bool = False,
|
||||||
|
sft_format: str = "deepseek",
|
||||||
|
mask_prompt: bool = True,
|
||||||
|
ignore_id: int = -100,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
# self.candidate_resolutions = candidate_resolutions # placeholder no use
|
||||||
|
self.image_size = IMAGE_SIZE
|
||||||
|
self.base_size = BASE_SIZE
|
||||||
|
# self.patch_size = patch_size
|
||||||
|
self.patch_size = 16
|
||||||
|
self.image_mean = image_mean
|
||||||
|
self.image_std = image_std
|
||||||
|
self.normalize = normalize
|
||||||
|
# self.downsample_ratio = downsample_ratio
|
||||||
|
self.downsample_ratio = 4
|
||||||
|
|
||||||
|
self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
|
||||||
|
|
||||||
|
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
# self.tokenizer = add_special_token(tokenizer)
|
||||||
|
self.tokenizer.padding_side = 'left'  # must set this; padding side makes a difference in batch inference
|
||||||
|
|
||||||
|
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.add_special_tokens({'pad_token': pad_token})
|
||||||
|
|
||||||
|
# add image token
|
||||||
|
# image_token_id = self.tokenizer.vocab.get(image_token)
|
||||||
|
# if image_token_id is None:
|
||||||
|
# special_tokens = [image_token]
|
||||||
|
# special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
# self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
self.image_token_id = self.tokenizer.vocab.get(image_token)
|
||||||
|
|
||||||
|
# add five special tokens for grounding-related tasks
|
||||||
|
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
|
||||||
|
# special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
|
||||||
|
# special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
|
||||||
|
# special_tokens = ['<image>','<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>', '<td>', '</td>', '<tr>', '</tr>']
|
||||||
|
# special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
# self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
|
||||||
|
# # add special tokens for SFT data
|
||||||
|
# special_tokens = ["<|User|>", "<|Assistant|>"]
|
||||||
|
# special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
# self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
|
||||||
|
self.image_token = image_token
|
||||||
|
self.pad_token = pad_token
|
||||||
|
self.add_special_token = add_special_token
|
||||||
|
self.sft_format = sft_format
|
||||||
|
self.mask_prompt = mask_prompt
|
||||||
|
self.ignore_id = ignore_id
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
tokenizer,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def select_best_resolution(self, image_size):
|
||||||
|
# # used for cropping
|
||||||
|
# original_width, original_height = image_size
|
||||||
|
# best_fit = None
|
||||||
|
# max_effective_resolution = 0
|
||||||
|
# min_wasted_resolution = float("inf")
|
||||||
|
|
||||||
|
# for width, height in self.candidate_resolutions:
|
||||||
|
# scale = min(width / original_width, height / original_height)
|
||||||
|
# downscaled_width, downscaled_height = int(
|
||||||
|
# original_width * scale), int(original_height * scale)
|
||||||
|
# effective_resolution = min(downscaled_width * downscaled_height,
|
||||||
|
# original_width * original_height)
|
||||||
|
# wasted_resolution = (width * height) - effective_resolution
|
||||||
|
|
||||||
|
# if effective_resolution > max_effective_resolution or (
|
||||||
|
# effective_resolution == max_effective_resolution
|
||||||
|
# and wasted_resolution < min_wasted_resolution):
|
||||||
|
# max_effective_resolution = effective_resolution
|
||||||
|
# min_wasted_resolution = wasted_resolution
|
||||||
|
# best_fit = (width, height)
|
||||||
|
|
||||||
|
# return best_fit
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_id(self):
|
||||||
|
return self.tokenizer.bos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_id(self):
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_id(self):
|
||||||
|
return self.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
def encode(self, text: str, bos: bool = True, eos: bool = False):
|
||||||
|
t = self.tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
if bos:
|
||||||
|
t = [self.bos_id] + t
|
||||||
|
if eos:
|
||||||
|
t = t + [self.eos_id]
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, t: List[int], **kwargs) -> str:
|
||||||
|
return self.tokenizer.decode(t, **kwargs)
|
||||||
|
|
||||||
|
def process_one(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
images: List,
|
||||||
|
inference_mode: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): the formatted prompt;
|
||||||
|
conversations (List[Dict]): conversations with a list of messages;
|
||||||
|
images (List[ImageType]): the list of images;
|
||||||
|
inference_mode (bool): if True, then remove the last eos token;
|
||||||
|
system_prompt (str): the system prompt;
|
||||||
|
**kwargs:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
outputs (BaseProcessorOutput): the output of the processor,
|
||||||
|
- input_ids (torch.LongTensor): [N + image tokens]
|
||||||
|
- target_ids (torch.LongTensor): [N + image tokens]
|
||||||
|
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
|
||||||
|
- image_id (int): the id of the image token
|
||||||
|
- num_image_tokens (List[int]): the number of image tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert (prompt is not None and images is not None
|
||||||
|
), "prompt and images must be used at the same time."
|
||||||
|
|
||||||
|
sft_format = prompt
|
||||||
|
|
||||||
|
input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
|
||||||
|
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"images_crop": images_crop,
|
||||||
|
"images_seq_mask": images_seq_mask,
|
||||||
|
"images_spatial_crop": images_spatial_crop,
|
||||||
|
"num_image_tokens": num_image_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# prepare = BatchFeature(
|
||||||
|
# data=dict(
|
||||||
|
# input_ids=input_ids,
|
||||||
|
# pixel_values=pixel_values,
|
||||||
|
# images_crop = images_crop,
|
||||||
|
# images_seq_mask=images_seq_mask,
|
||||||
|
# images_spatial_crop=images_spatial_crop,
|
||||||
|
# num_image_tokens=num_image_tokens,
|
||||||
|
# ),
|
||||||
|
# tensor_type="pt",
|
||||||
|
# )
|
||||||
|
# return prepare
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
images: List,
|
||||||
|
inference_mode: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): the formatted prompt;
|
||||||
|
images (List[ImageType]): the list of images;
|
||||||
|
inference_mode (bool): if True, then remove the last eos token;
|
||||||
|
**kwargs:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
outputs (BaseProcessorOutput): the output of the processor,
|
||||||
|
- input_ids (torch.LongTensor): [N + image tokens]
|
||||||
|
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||||
|
- image_id (int): the id of the image token
|
||||||
|
- num_image_tokens (List[int]): the number of image tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
prepare = self.process_one(
|
||||||
|
prompt=prompt,
|
||||||
|
images=images,
|
||||||
|
inference_mode=inference_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prepare
|
||||||
|
|
||||||
|
def tokenize_with_images(
|
||||||
|
self,
|
||||||
|
# conversation: str,
|
||||||
|
images: List[Image.Image],
|
||||||
|
bos: bool = True,
|
||||||
|
eos: bool = True,
|
||||||
|
cropping: bool = True,
|
||||||
|
):
|
||||||
|
"""Tokenize text with <image> tags."""
|
||||||
|
|
||||||
|
# print(conversation)
|
||||||
|
conversation = PROMPT
|
||||||
|
assert conversation.count(self.image_token) == len(images)
|
||||||
|
text_splits = conversation.split(self.image_token)
|
||||||
|
images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
|
||||||
|
image_shapes = []
|
||||||
|
num_image_tokens = []
|
||||||
|
tokenized_str = []
|
||||||
|
# print('image: ', len(images))
|
||||||
|
for text_sep, image in zip(text_splits, images):
|
||||||
|
"""encode text_sep"""
|
||||||
|
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
|
||||||
|
tokenized_str += tokenized_sep
|
||||||
|
images_seq_mask += [False] * len(tokenized_sep)
|
||||||
|
|
||||||
|
"""select best resolution for anyres"""
|
||||||
|
# if cropping:
|
||||||
|
# best_width, best_height = self.select_best_resolution(image.size)
|
||||||
|
# else:
|
||||||
|
# best_width, best_height = self.image_size, self.image_size
|
||||||
|
|
||||||
|
image_shapes.append(image.size)
|
||||||
|
|
||||||
|
if image.size[0] <= 640 and image.size[1] <= 640:
|
||||||
|
crop_ratio = [1, 1]
|
||||||
|
else:
|
||||||
|
if cropping:
|
||||||
|
# print('image-size: ', image.size)
|
||||||
|
# best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
|
||||||
|
# print('image ', image.size)
|
||||||
|
# print('open_size:', image.size)
|
||||||
|
images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
|
||||||
|
# print('crop_ratio: ', crop_ratio)
|
||||||
|
else:
|
||||||
|
# best_width, best_height = self.image_size, self.image_size
|
||||||
|
crop_ratio = [1, 1]
|
||||||
|
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
|
||||||
|
|
||||||
|
# print(crop_ratio)
|
||||||
|
"""process the global view"""
|
||||||
|
|
||||||
|
# if cropping
|
||||||
|
if self.image_size <= 640 and not cropping:
|
||||||
|
# print('directly resize')
|
||||||
|
image = image.resize((self.image_size, self.image_size))
|
||||||
|
|
||||||
|
global_view = ImageOps.pad(image, (self.base_size, self.base_size),
|
||||||
|
color=tuple(int(x * 255) for x in self.image_transform.mean))
|
||||||
|
images_list.append(self.image_transform(global_view))
|
||||||
|
|
||||||
|
"""record height / width crop num"""
|
||||||
|
# width_crop_num, height_crop_num = best_width // self.image_size, best_height // self.image_size
|
||||||
|
num_width_tiles, num_height_tiles = crop_ratio
|
||||||
|
images_spatial_crop.append([num_width_tiles, num_height_tiles])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if num_width_tiles > 1 or num_height_tiles > 1:
|
||||||
|
"""process the local views"""
|
||||||
|
# local_view = ImageOps.pad(image, (best_width, best_height),
|
||||||
|
# color=tuple(int(x * 255) for x in self.image_transform.mean))
|
||||||
|
# for i in range(0, best_height, self.image_size):
|
||||||
|
# for j in range(0, best_width, self.image_size):
|
||||||
|
# images_crop_list.append(
|
||||||
|
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
|
||||||
|
for i in range(len(images_crop_raw)):
|
||||||
|
images_crop_list.append(self.image_transform(images_crop_raw[i]))
|
||||||
|
|
||||||
|
# """process the global view"""
|
||||||
|
# global_view = ImageOps.pad(image, (self.image_size, self.image_size),
|
||||||
|
# color=tuple(int(x * 255) for x in self.image_transform.mean))
|
||||||
|
# images_list.append(self.image_transform(global_view))
|
||||||
|
|
||||||
|
# """process the local views"""
|
||||||
|
# local_view = ImageOps.pad(image, (best_width, best_height),
|
||||||
|
# color=tuple(int(x * 255) for x in self.image_transform.mean))
|
||||||
|
# for i in range(0, best_height, self.image_size):
|
||||||
|
# for j in range(0, best_width, self.image_size):
|
||||||
|
# images_list.append(
|
||||||
|
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
|
||||||
|
|
||||||
|
# """add image tokens"""
|
||||||
|
"""add image tokens"""
|
||||||
|
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
|
||||||
|
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
|
||||||
|
|
||||||
|
|
||||||
|
tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
|
||||||
|
tokenized_image += [self.image_token_id]
|
||||||
|
if num_width_tiles > 1 or num_height_tiles > 1:
|
||||||
|
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * (
|
||||||
|
num_queries * num_height_tiles)
|
||||||
|
tokenized_str += tokenized_image
|
||||||
|
images_seq_mask += [True] * len(tokenized_image)
|
||||||
|
num_image_tokens.append(len(tokenized_image))
|
||||||
|
|
||||||
|
"""process the last text split"""
|
||||||
|
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
|
||||||
|
tokenized_str += tokenized_sep
|
||||||
|
images_seq_mask += [False] * len(tokenized_sep)
|
||||||
|
|
||||||
|
"""add the bos and eos tokens"""
|
||||||
|
if bos:
|
||||||
|
tokenized_str = [self.bos_id] + tokenized_str
|
||||||
|
images_seq_mask = [False] + images_seq_mask
|
||||||
|
if eos:
|
||||||
|
tokenized_str = tokenized_str + [self.eos_id]
|
||||||
|
images_seq_mask = images_seq_mask + [False]
|
||||||
|
|
||||||
|
assert len(tokenized_str) == len(
|
||||||
|
images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
masked_tokenized_str = []
|
||||||
|
for token_index in tokenized_str:
|
||||||
|
if token_index != self.image_token_id:
|
||||||
|
masked_tokenized_str.append(token_index)
|
||||||
|
else:
|
||||||
|
masked_tokenized_str.append(self.ignore_id)
|
||||||
|
|
||||||
|
assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
|
||||||
|
(f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
|
||||||
|
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
|
||||||
|
|
||||||
|
input_ids = torch.LongTensor(tokenized_str)
|
||||||
|
target_ids = torch.LongTensor(masked_tokenized_str)
|
||||||
|
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
||||||
|
|
||||||
|
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
|
||||||
|
target_ids[(input_ids < 0) |
|
||||||
|
(input_ids == self.image_token_id)] = self.ignore_id
|
||||||
|
input_ids[input_ids < 0] = self.pad_id
|
||||||
|
|
||||||
|
inference_mode = True
|
||||||
|
|
||||||
|
if inference_mode:
|
||||||
|
# Remove the ending eos token
|
||||||
|
assert input_ids[-1] == self.eos_id
|
||||||
|
input_ids = input_ids[:-1]
|
||||||
|
target_ids = target_ids[:-1]
|
||||||
|
images_seq_mask = images_seq_mask[:-1]
|
||||||
|
|
||||||
|
if len(images_list) == 0:
|
||||||
|
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
|
||||||
|
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
|
||||||
|
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
pixel_values = torch.stack(images_list, dim=0)
|
||||||
|
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
||||||
|
if images_crop_list:
|
||||||
|
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
|
||||||
|
|
||||||
|
input_ids = input_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
|
||||||
|
|
||||||
|
|
||||||
|
AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)
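# Minimal usage sketch (illustrative, not part of the original file; 'page.jpg'
# is a placeholder path): build the multimodal payload for one page image the
# same way the run_* scripts do. The nested return value is
# [[input_ids, pixel_values, images_crop, images_seq_mask,
#   images_spatial_crop, num_image_tokens, image_shapes]].
#
#   from PIL import Image
#   features = DeepseekOCRProcessor().tokenize_with_images(
#       images=[Image.open('page.jpg').convert('RGB')],
#       bos=True, eos=True, cropping=CROP_MODE)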
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
import torch
from transformers import LogitsProcessor
from transformers.generation.logits_process import _calc_banned_ngram_tokens
from typing import List, Set


class NoRepeatNGramLogitsProcessor(LogitsProcessor):

    def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
        self.ngram_size = ngram_size
        self.window_size = window_size
        self.whitelist_token_ids = whitelist_token_ids or set()

    def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
        if len(input_ids) < self.ngram_size:
            return scores

        current_prefix = tuple(input_ids[-(self.ngram_size - 1):])

        search_start = max(0, len(input_ids) - self.window_size)
        search_end = len(input_ids) - self.ngram_size + 1

        banned_tokens = set()
        for i in range(search_start, search_end):
            ngram = tuple(input_ids[i:i + self.ngram_size])
            if ngram[:-1] == current_prefix:
                banned_tokens.add(ngram[-1])

        banned_tokens = banned_tokens - self.whitelist_token_ids

        if banned_tokens:
            scores = scores.clone()
            for token in banned_tokens:
                scores[token] = -float("inf")

        return scores
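# Worked example (illustrative, not part of the original file): with
# ngram_size=3 and window_size=100, if the last two generated ids are (5, 7)
# and the 3-gram (5, 7, 9) already occurred inside the trailing window, then
# token 9 is banned (its logit is set to -inf) unless 9 is whitelisted.
# The run_* scripts whitelist ids 128821/128822 (<td>, </td>) so long tables
# are not cut short by the repetition filter.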
161
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_eval_batch.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
if torch.version.cuda == '11.8':
|
||||||
|
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
|
||||||
|
os.environ['VLLM_USE_V1'] = '0'
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
||||||
|
|
||||||
|
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, MAX_CONCURRENCY, CROP_MODE, NUM_WORKERS
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import glob
|
||||||
|
from PIL import Image
|
||||||
|
from deepseek_ocr import DeepseekOCRForCausalLM
|
||||||
|
|
||||||
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
||||||
|
from process.image_process import DeepseekOCRProcessor
|
||||||
|
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||||||
|
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model=MODEL_PATH,
|
||||||
|
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
||||||
|
block_size=256,
|
||||||
|
enforce_eager=False,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=8192,
|
||||||
|
swap_space=0,
|
||||||
|
max_num_seqs = MAX_CONCURRENCY,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
gpu_memory_utilization=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=40, window_size=90, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=8192,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Colors:
|
||||||
|
RED = '\033[31m'
|
||||||
|
GREEN = '\033[32m'
|
||||||
|
YELLOW = '\033[33m'
|
||||||
|
BLUE = '\033[34m'
|
||||||
|
RESET = '\033[0m'
|
||||||
|
|
||||||
|
def clean_formula(text):
|
||||||
|
|
||||||
|
formula_pattern = r'\\\[(.*?)\\\]'
|
||||||
|
|
||||||
|
def process_formula(match):
|
||||||
|
formula = match.group(1)
|
||||||
|
|
||||||
|
formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
|
||||||
|
|
||||||
|
formula = formula.strip()
|
||||||
|
|
||||||
|
return r'\[' + formula + r'\]'
|
||||||
|
|
||||||
|
cleaned_text = re.sub(formula_pattern, process_formula, text)
|
||||||
|
|
||||||
|
return cleaned_text
|
||||||
|
|
||||||
|
def re_match(text):
|
||||||
|
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
# mathes_image = []
|
||||||
|
mathes_other = []
|
||||||
|
for a_match in matches:
|
||||||
|
mathes_other.append(a_match[0])
|
||||||
|
return matches, mathes_other
|
||||||
|
|
||||||
|
def process_single_image(image):
|
||||||
|
"""single image"""
|
||||||
|
prompt_in = prompt
|
||||||
|
cache_item = {
|
||||||
|
"prompt": prompt_in,
|
||||||
|
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
|
||||||
|
}
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# INPUT_PATH = OmniDocBench images path
|
||||||
|
|
||||||
|
os.makedirs(OUTPUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
# print('image processing until processing prompts.....')
|
||||||
|
|
||||||
|
print(f'{Colors.RED}glob images.....{Colors.RESET}')
|
||||||
|
|
||||||
|
images_path = glob.glob(f'{INPUT_PATH}/*')
|
||||||
|
|
||||||
|
images = []
|
||||||
|
|
||||||
|
for image_path in images_path:
|
||||||
|
image = Image.open(image_path).convert('RGB')
|
||||||
|
images.append(image)
|
||||||
|
|
||||||
|
prompt = PROMPT
|
||||||
|
|
||||||
|
# batch_inputs = []
|
||||||
|
|
||||||
|
|
||||||
|
# for image in tqdm(images):
|
||||||
|
|
||||||
|
# prompt_in = prompt
|
||||||
|
# cache_list = [
|
||||||
|
# {
|
||||||
|
# "prompt": prompt_in,
|
||||||
|
# "multi_modal_data": {"image": Image.open(image).convert('RGB')},
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# batch_inputs.extend(cache_list)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
|
||||||
|
batch_inputs = list(tqdm(
|
||||||
|
executor.map(process_single_image, images),
|
||||||
|
total=len(images),
|
||||||
|
desc="Pre-processed images"
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
outputs_list = llm.generate(
|
||||||
|
batch_inputs,
|
||||||
|
sampling_params=sampling_params
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
output_path = OUTPUT_PATH
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
for output, image in zip(outputs_list, images_path):
|
||||||
|
|
||||||
|
content = output.outputs[0].text
|
||||||
|
mmd_det_path = output_path + image.split('/')[-1].replace('.jpg', '_det.md')
|
||||||
|
|
||||||
|
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
|
||||||
|
afile.write(content)
|
||||||
|
|
||||||
|
content = clean_formula(content)
|
||||||
|
matches_ref, mathes_other = re_match(content)
|
||||||
|
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
|
||||||
|
content = content.replace(a_match_other, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n').replace('<center>', '').replace('</center>', '')
|
||||||
|
|
||||||
|
mmd_path = output_path + image.split('/')[-1].replace('.jpg', '.md')
|
||||||
|
|
||||||
|
with open(mmd_path, 'w', encoding='utf-8') as afile:
|
||||||
|
afile.write(content)
|
||||||
303
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_image.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
if torch.version.cuda == '11.8':
|
||||||
|
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
|
||||||
|
|
||||||
|
os.environ['VLLM_USE_V1'] = '0'
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
||||||
|
|
||||||
|
from vllm import AsyncLLMEngine, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
|
import time
|
||||||
|
from deepseek_ocr import DeepseekOCRForCausalLM
|
||||||
|
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
||||||
|
from process.image_process import DeepseekOCRProcessor
|
||||||
|
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||||||
|
|
||||||
|
def load_image(image_path):
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
|
||||||
|
corrected_image = ImageOps.exif_transpose(image)
|
||||||
|
|
||||||
|
return corrected_image
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"error: {e}")
|
||||||
|
try:
|
||||||
|
return Image.open(image_path)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def re_match(text):
|
||||||
|
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
mathes_image = []
|
||||||
|
mathes_other = []
|
||||||
|
for a_match in matches:
|
||||||
|
if '<|ref|>image<|/ref|>' in a_match[0]:
|
||||||
|
mathes_image.append(a_match[0])
|
||||||
|
else:
|
||||||
|
mathes_other.append(a_match[0])
|
||||||
|
return matches, mathes_image, mathes_other
|
||||||
|
|
||||||
|
|
||||||
|
def extract_coordinates_and_label(ref_text, image_width, image_height):
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
label_type = ref_text[1]
|
||||||
|
cor_list = eval(ref_text[2])
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (label_type, cor_list)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_bounding_boxes(image, refs):
|
||||||
|
|
||||||
|
image_width, image_height = image.size
|
||||||
|
img_draw = image.copy()
|
||||||
|
draw = ImageDraw.Draw(img_draw)
|
||||||
|
|
||||||
|
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
|
||||||
|
draw2 = ImageDraw.Draw(overlay)
|
||||||
|
|
||||||
|
# except IOError:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
|
||||||
|
img_idx = 0
|
||||||
|
|
||||||
|
for i, ref in enumerate(refs):
|
||||||
|
try:
|
||||||
|
result = extract_coordinates_and_label(ref, image_width, image_height)
|
||||||
|
if result:
|
||||||
|
label_type, points_list = result
|
||||||
|
|
||||||
|
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
|
||||||
|
|
||||||
|
color_a = color + (20, )
|
||||||
|
for points in points_list:
|
||||||
|
x1, y1, x2, y2 = points
|
||||||
|
|
||||||
|
x1 = int(x1 / 999 * image_width)
|
||||||
|
y1 = int(y1 / 999 * image_height)
|
||||||
|
|
||||||
|
x2 = int(x2 / 999 * image_width)
|
||||||
|
y2 = int(y2 / 999 * image_height)
|
||||||
|
|
||||||
|
if label_type == 'image':
|
||||||
|
try:
|
||||||
|
cropped = image.crop((x1, y1, x2, y2))
|
||||||
|
cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
img_idx += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
if label_type == 'title':
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
|
||||||
|
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
|
||||||
|
else:
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
|
||||||
|
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
|
||||||
|
|
||||||
|
text_x = x1
|
||||||
|
text_y = max(0, y1 - 15)
|
||||||
|
|
||||||
|
text_bbox = draw.textbbox((0, 0), label_type, font=font)
|
||||||
|
text_width = text_bbox[2] - text_bbox[0]
|
||||||
|
text_height = text_bbox[3] - text_bbox[1]
|
||||||
|
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
|
||||||
|
fill=(255, 255, 255, 30))
|
||||||
|
|
||||||
|
draw.text((text_x, text_y), label_type, font=font, fill=color)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
img_draw.paste(overlay, (0, 0), overlay)
|
||||||
|
return img_draw
|
||||||
|
|
||||||
|
|
||||||
|
def process_image_with_refs(image, ref_texts):
|
||||||
|
result_image = draw_bounding_boxes(image, ref_texts)
|
||||||
|
return result_image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_generate(image=None, prompt=''):
|
||||||
|
|
||||||
|
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=MODEL_PATH,
|
||||||
|
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
||||||
|
block_size=256,
|
||||||
|
max_model_len=8192,
|
||||||
|
enforce_eager=False,
|
||||||
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
gpu_memory_utilization=0.75,
|
||||||
|
)
|
||||||
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})] #whitelist: <td>, </td>
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=8192,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
# ignore_eos=False,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id = f"request-{int(time.time())}"
|
||||||
|
|
||||||
|
printed_length = 0
|
||||||
|
|
||||||
|
if image and '<image>' in prompt:
|
||||||
|
request = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": {"image": image}
|
||||||
|
}
|
||||||
|
elif prompt:
|
||||||
|
request = {
|
||||||
|
"prompt": prompt
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
assert False, f'prompt is none!!!'
|
||||||
|
async for request_output in engine.generate(
|
||||||
|
request, sampling_params, request_id
|
||||||
|
):
|
||||||
|
if request_output.outputs:
|
||||||
|
full_text = request_output.outputs[0].text
|
||||||
|
new_text = full_text[printed_length:]
|
||||||
|
print(new_text, end='', flush=True)
|
||||||
|
printed_length = len(full_text)
|
||||||
|
final_output = full_text
|
||||||
|
print('\n')
|
||||||
|
|
||||||
|
return final_output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
os.makedirs(OUTPUT_PATH, exist_ok=True)
|
||||||
|
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
|
||||||
|
|
||||||
|
image = load_image(INPUT_PATH).convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
|
if '<image>' in PROMPT:
|
||||||
|
|
||||||
|
image_features = DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)
|
||||||
|
else:
|
||||||
|
image_features = ''
|
||||||
|
|
||||||
|
prompt = PROMPT
|
||||||
|
|
||||||
|
result_out = asyncio.run(stream_generate(image_features, prompt))
|
||||||
|
|
||||||
|
|
||||||
|
save_results = 1
|
||||||
|
|
||||||
|
if save_results and '<image>' in prompt:
|
||||||
|
print('='*15 + 'save results:' + '='*15)
|
||||||
|
|
||||||
|
image_draw = image.copy()
|
||||||
|
|
||||||
|
outputs = result_out
|
||||||
|
|
||||||
|
with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile:
|
||||||
|
afile.write(outputs)
|
||||||
|
|
||||||
|
matches_ref, matches_images, mathes_other = re_match(outputs)
|
||||||
|
# print(matches_ref)
|
||||||
|
result = process_image_with_refs(image_draw, matches_ref)
|
||||||
|
|
||||||
|
|
||||||
|
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
|
||||||
|
outputs = outputs.replace(a_match_image, f'![](images/{idx}.jpg)\n')
|
||||||
|
|
||||||
|
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
|
||||||
|
outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
|
||||||
|
|
||||||
|
# if 'structural formula' in conversation[0]['content']:
|
||||||
|
# outputs = '<smiles>' + outputs + '</smiles>'
|
||||||
|
with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile:
|
||||||
|
afile.write(outputs)
|
||||||
|
|
||||||
|
if 'line_type' in outputs:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Circle
|
||||||
|
lines = eval(outputs)['Line']['line']
|
||||||
|
|
||||||
|
line_type = eval(outputs)['Line']['line_type']
|
||||||
|
# print(lines)
|
||||||
|
|
||||||
|
endpoints = eval(outputs)['Line']['line_endpoint']
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(3,3), dpi=200)
|
||||||
|
ax.set_xlim(-15, 15)
|
||||||
|
ax.set_ylim(-15, 15)
|
||||||
|
|
||||||
|
for idx, line in enumerate(lines):
|
||||||
|
try:
|
||||||
|
p0 = eval(line.split(' -- ')[0])
|
||||||
|
p1 = eval(line.split(' -- ')[-1])
|
||||||
|
|
||||||
|
if line_type[idx] == '--':
|
||||||
|
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
|
||||||
|
else:
|
||||||
|
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
|
||||||
|
|
||||||
|
                    ax.scatter(p0[0], p0[1], s=5, color='k')
                    ax.scatter(p1[0], p1[1], s=5, color='k')
                except:
                    pass

            for endpoint in endpoints:

                label = endpoint.split(': ')[0]
                (x, y) = eval(endpoint.split(': ')[1])
                ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                            fontsize=5, fontweight='light')

            try:
                if 'Circle' in eval(outputs).keys():
                    circle_centers = eval(outputs)['Circle']['circle_center']
                    radius = eval(outputs)['Circle']['radius']

                    for center, r in zip(circle_centers, radius):
                        center = eval(center.split(': ')[1])
                        circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8)
                        ax.add_patch(circle)
            except:
                pass

            plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
            plt.close()

        result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg')
330
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py
Normal file
@ -0,0 +1,330 @@
import os
import fitz
import img2pdf
import io
import re
from tqdm import tqdm
import torch
from concurrent.futures import ThreadPoolExecutor


if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'


from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE

from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM

from vllm.model_executor.models.registry import ModelRegistry

from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor

ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)


llm = LLM(
    model=MODEL_PATH,
    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
    block_size=256,
    enforce_eager=False,
    trust_remote_code=True,
    max_model_len=8192,
    swap_space=0,
    max_num_seqs=MAX_CONCURRENCY,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    disable_mm_preprocessor_cache=True
)

logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822})]  # window for fast; whitelist_token_ids: <td>, </td>

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8192,
    logits_processors=logits_processors,
    skip_special_tokens=False,
    include_stop_str_in_output=True,
)


class Colors:
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    RESET = '\033[0m'


def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
    """
    pdf2images
    """
    images = []

    pdf_document = fitz.open(pdf_path)

    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)

    for page_num in range(pdf_document.page_count):
        page = pdf_document[page_num]

        pixmap = page.get_pixmap(matrix=matrix, alpha=False)
        Image.MAX_IMAGE_PIXELS = None

        if image_format.upper() == "PNG":
            img_data = pixmap.tobytes("png")
            img = Image.open(io.BytesIO(img_data))
        else:
            img_data = pixmap.tobytes("png")
            img = Image.open(io.BytesIO(img_data))
            if img.mode in ('RGBA', 'LA'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
                img = background

        images.append(img)

    pdf_document.close()
    return images


def pil_to_pdf_img2pdf(pil_images, output_path):

    if not pil_images:
        return

    image_bytes_list = []

    for img in pil_images:
        if img.mode != 'RGB':
            img = img.convert('RGB')

        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG', quality=95)
        img_bytes = img_buffer.getvalue()
        image_bytes_list.append(img_bytes)

    try:
        pdf_bytes = img2pdf.convert(image_bytes_list)
        with open(output_path, "wb") as f:
            f.write(pdf_bytes)

    except Exception as e:
        print(f"error: {e}")
def re_match(text):
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    mathes_image = []
    mathes_other = []
    for a_match in matches:
        if '<|ref|>image<|/ref|>' in a_match[0]:
            mathes_image.append(a_match[0])
        else:
            mathes_other.append(a_match[0])
    return matches, mathes_image, mathes_other


def extract_coordinates_and_label(ref_text, image_width, image_height):

    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)


def draw_bounding_boxes(image, refs, jdx):

    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    # except IOError:
    font = ImageFont.load_default()

    img_idx = 0

    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))

                color_a = color + (20, )
                for points in points_list:
                    x1, y1, x2, y2 = points

                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)

                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                            pass
                        img_idx += 1

                    try:
                        if label_type == 'title':
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                        text_x = x1
                        text_y = max(0, y1 - 15)

                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                       fill=(255, 255, 255, 30))

                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        pass
        except:
            continue
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
def process_image_with_refs(image, ref_texts, jdx):
    result_image = draw_bounding_boxes(image, ref_texts, jdx)
    return result_image


def process_single_image(image):
    """single image"""
    prompt_in = prompt
    cache_item = {
        "prompt": prompt_in,
        "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
    }
    return cache_item

if __name__ == "__main__":

    os.makedirs(OUTPUT_PATH, exist_ok=True)
    os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)

    print(f'{Colors.RED}PDF loading .....{Colors.RESET}')

    images = pdf_to_images_high_quality(INPUT_PATH)

    prompt = PROMPT

    # batch_inputs = []

    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        batch_inputs = list(tqdm(
            executor.map(process_single_image, images),
            total=len(images),
            desc="Pre-processed images"
        ))

    # for image in tqdm(images):
    #     prompt_in = prompt
    #     cache_list = [
    #         {
    #             "prompt": prompt_in,
    #             "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
    #         }
    #     ]
    #     batch_inputs.extend(cache_list)

    outputs_list = llm.generate(
        batch_inputs,
        sampling_params=sampling_params
    )

    output_path = OUTPUT_PATH

    os.makedirs(output_path, exist_ok=True)

    # output file names derived from the input PDF name
    mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
    mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '.mmd')
    pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
    contents_det = ''
    contents = ''
    draw_images = []
    jdx = 0
    for output, img in zip(outputs_list, images):
        content = output.outputs[0].text

        if '<|end▁of▁sentence|>' in content:  # repeat no eos
            content = content.replace('<|end▁of▁sentence|>', '')
        else:
            if SKIP_REPEAT:
                continue

        page_num = f'\n<--- Page Split --->'

        contents_det += content + f'\n{page_num}\n'

        image_draw = img.copy()

        matches_ref, matches_images, mathes_other = re_match(content)
        # print(matches_ref)
        result_image = process_image_with_refs(image_draw, matches_ref, jdx)

        draw_images.append(result_image)

        for idx, a_match_image in enumerate(matches_images):
            # swap each grounded image tag for a markdown link to the cropped figure saved above
            content = content.replace(a_match_image, f'![](images/{jdx}_{idx}.jpg)\n')

        for idx, a_match_other in enumerate(mathes_other):
            content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')

        contents += content + f'\n{page_num}\n'

        jdx += 1

    with open(mmd_det_path, 'w', encoding='utf-8') as afile:
        afile.write(contents_det)

    with open(mmd_path, 'w', encoding='utf-8') as afile:
        afile.write(contents)

    pil_to_pdf_img2pdf(draw_images, pdf_out_path)
BIN
DeepSeek_OCR_paper.pdf
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
181
README.md
Normal file
@ -0,0 +1,181 @@
<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->
<!-- markdownlint-disable no-duplicate-header -->

<div align="center">
  <img src="assets/logo.svg" width="60%" alt="DeepSeek AI" />
</div>

<hr>
<div align="center">
  <a href="https://www.deepseek.com/" target="_blank">
    <img alt="Homepage" src="assets/badge.svg" />
  </a>
  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
  </a>
</div>

<div align="center">
  <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
    <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
  </a>
  <a href="https://twitter.com/deepseek_ai" target="_blank">
    <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
  </a>
</div>

<p align="center">
  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
  <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
  <a href="./DeepSeek_OCR_paper.pdf"><b>📄 Arxiv Paper Link</b></a> |
</p>

<h2>
  <p align="center">
    <a href="">DeepSeek-OCR: Contexts Optical Compression</a>
  </p>
</h2>

<p align="center">
  <img src="assets/fig1.png" style="width: 1000px" align=center>
</p>
<p align="center">
  <a href="">Explore the boundaries of visual-text compression.</a>
</p>

## Release
- [2025/x/x]🚀🚀🚀 We release DeepSeek-OCR, a model to investigate the role of vision encoders from an LLM-centric viewpoint.

## Contents
- [Install](#install)
- [vLLM Inference](#vllm-inference)
- [Transformers Inference](#transformers-inference)

## Install
>Our environment is CUDA 11.8 + torch 2.6.0.
1. Clone this repository and navigate to the DeepSeek-OCR folder
```bash
git clone https://github.com/deepseek-ai/DeepSeek-OCR.git
cd DeepSeek-OCR
```
2. Conda
```Shell
conda create -n deepseek-ocr python=3.12.9 -y
conda activate deepseek-ocr
```
3. Packages

- Download the vllm-0.8.5 [whl](https://github.com/vllm-project/vllm/releases/tag/v0.8.5)
```Shell
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118
pip install vllm-0.8.5+cu118-cp38-abi3-manylinux1_x86_64.whl
pip install -r requirements.txt
pip install flash-attn==2.7.3 --no-build-isolation
```
**Note:** if you want the vLLM and Transformers code to run in the same environment, you can safely ignore an installation error such as: `vllm 0.8.5+cu118 requires transformers>=4.51.1`.

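After installing, you can optionally sanity-check the environment. A minimal sketch; the expected versions are the ones pinned above and in `requirements.txt`:

```python
# Optional environment check (illustrative): print the CUDA build and key package versions.
import torch, transformers, vllm

print('torch', torch.__version__, '| cuda', torch.version.cuda)  # expect 2.6.0 / 11.8
print('transformers', transformers.__version__)                  # requirements.txt pins 4.46.3
print('vllm', vllm.__version__)                                   # the wheel above is 0.8.5
```
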
## vLLM-Inference
- VLLM:
>**Note:** change the INPUT_PATH/OUTPUT_PATH and other settings in `DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py`
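For a single run, only the two paths usually need editing. A small sketch of such an edit; the paths below are placeholders, not values shipped with the repo:

```python
# In DeepSeek-OCR-vllm/config.py -- placeholder paths, adjust to your data
INPUT_PATH = '/data/docs/sample.pdf'      # a .pdf for run_dpsk_ocr_pdf.py, or a .jpg/.png for run_dpsk_ocr_image.py
OUTPUT_PATH = '/data/ocr_outputs/sample'  # results (.mmd files, cropped figures, layout PDF) are written here
```
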
```Shell
cd DeepSeek-OCR-master/DeepSeek-OCR-vllm
```
1. image: streaming output
```Shell
python run_dpsk_ocr_image.py
```
2. pdf: concurrent batch inference, ~2500 tokens/s (on an A100-40G)
```Shell
python run_dpsk_ocr_pdf.py
```
3. batch eval for benchmarks
```Shell
python run_dpsk_ocr_eval_batch.py
```
## Transformers-Inference
- Transformers
```python
from transformers import AutoModel, AutoTokenizer
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
model_name = 'deepseek-ai/DeepSeek-OCR'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda().to(torch.bfloat16)

# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
```
or you can run:
```Shell
cd DeepSeek-OCR-master/DeepSeek-OCR-hf
python run_dpsk_ocr.py
```
## Support-Modes
The current open-source model supports the following modes:
- Native resolution:
  - Tiny: 512×512 (64 vision tokens) ✅
  - Small: 640×640 (100 vision tokens) ✅
  - Base: 1024×1024 (256 vision tokens) ✅
  - Large: 1280×1280 (400 vision tokens) ✅
- Dynamic resolution
  - Gundam: n×640×640 + 1×1024×1024 ✅

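The token counts listed above follow a simple pattern: each view contributes (resolution / 64)² vision tokens, and Gundam mode combines n local 640×640 crops with one 1024×1024 global view. A small sketch of that arithmetic; the helper below is illustrative, not part of the repo:

```python
# Illustrative only: reproduce the vision-token counts listed above.
def vision_tokens(size: int) -> int:
    # 512 -> 64, 640 -> 100, 1024 -> 256, 1280 -> 400 tokens per view
    return (size // 64) ** 2

def gundam_tokens(n_crops: int) -> int:
    # Gundam: n x 640x640 local crops + one 1024x1024 global view
    return n_crops * vision_tokens(640) + vision_tokens(1024)

for name, size in [('Tiny', 512), ('Small', 640), ('Base', 1024), ('Large', 1280)]:
    print(name, vision_tokens(size))
print('Gundam with 3 crops:', gundam_tokens(3))  # 3*100 + 256 = 556
```
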
## Prompts examples
```python
# document: <image>\n<|grounding|>Convert the document to markdown.
# other image: <image>\n<|grounding|>OCR this image.
# without layouts: <image>\nFree OCR.
# figures in document: <image>\nParse the figure.
# general: <image>\nDescribe this image in detail.
# rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
# e.g. '先天下之忧而忧' ("worry before the whole world worries") as the text to locate
```
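Any of these prompts can be passed to the `infer` call shown in the Transformers example above. A hedged sketch using the locate prompt, assuming `model` and `tokenizer` were created exactly as in that example; the image file and query text are placeholders:

```python
# Illustrative: reuse `model` and `tokenizer` from the Transformers-Inference block above.
# 'receipt.jpg' and the query text are placeholders.
prompt = "<image>\nLocate <|ref|>Total amount<|/ref|> in the image."
res = model.infer(tokenizer, prompt=prompt, image_file='receipt.jpg', output_path='your/output/dir',
                  base_size=1024, image_size=640, crop_mode=True, save_results=True)
```
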
## Visualizations
<table>
  <tr>
    <td><img src="assets/show1.jpg" style="width: 500px"></td>
    <td><img src="assets/show2.jpg" style="width: 500px"></td>
  </tr>
  <tr>
    <td><img src="assets/show3.jpg" style="width: 500px"></td>
    <td><img src="assets/show4.jpg" style="width: 500px"></td>
  </tr>
</table>

## Acknowledgement

We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.

We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).

## Citation

Coming soon!
1
assets/badge.svg
Normal file
BIN
assets/fig1.png
Normal file
22
assets/logo.svg
Normal file
@ -0,0 +1,22 @@
<svg width="195.000000" height="41.359375" viewBox="0 0 195 41.3594" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<desc>
Created with Pixso.
</desc>
<defs>
<clipPath id="clip30_2029">
<rect id="_图层_1" width="134.577469" height="25.511124" transform="translate(60.422485 10.022217)" fill="white"/>
</clipPath>
</defs>
<g clip-path="url(#clip30_2029)">
<path id="path" d="M119.508 30.113L117.562 30.113L117.562 27.0967L119.508 27.0967C120.713 27.0967 121.931 26.7961 122.715 25.9614C123.5 25.1265 123.796 23.8464 123.796 22.5664C123.796 21.2864 123.512 20.0063 122.715 19.1716C121.919 18.3369 120.713 18.0364 119.508 18.0364C118.302 18.0364 117.085 18.3369 116.3 19.1716C115.515 20.0063 115.219 21.2864 115.219 22.5664L115.219 34.9551L111.806 34.9551L111.806 15.031L115.219 15.031L115.219 16.2998L115.845 16.2998C115.913 16.2219 115.981 16.1553 116.049 16.0884C116.903 15.3093 118.211 15.031 119.496 15.031C121.51 15.031 123.523 15.532 124.843 16.9233C126.162 18.3145 126.629 20.4517 126.629 22.5776C126.629 24.7036 126.151 26.8296 124.843 28.2319C123.535 29.6345 121.51 30.113 119.508 30.113Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M67.5664 15.5654L69.5117 15.5654L69.5117 18.5818L67.5664 18.5818C66.3606 18.5818 65.1434 18.8823 64.3585 19.717C63.5736 20.552 63.2778 21.832 63.2778 23.1121C63.2778 24.3921 63.5623 25.6721 64.3585 26.5068C65.1548 27.3418 66.3606 27.6423 67.5664 27.6423C68.7722 27.6423 69.9895 27.3418 70.7744 26.5068C71.5593 25.6721 71.8551 24.3921 71.8551 23.1121L71.8551 10.7124L75.2677 10.7124L75.2677 30.6475L71.8551 30.6475L71.8551 29.3787L71.2294 29.3787C71.1611 29.4565 71.0929 29.5234 71.0247 29.5901C70.1715 30.3691 68.8633 30.6475 67.5779 30.6475C65.5643 30.6475 63.5509 30.1467 62.2313 28.7554C60.9117 27.364 60.4453 25.2268 60.4453 23.1008C60.4453 20.9749 60.9231 18.8489 62.2313 17.4465C63.5509 16.0552 65.5643 15.5654 67.5664 15.5654Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M92.3881 22.845L92.3881 24.0581L83.299 24.0581L83.299 21.6428L89.328 21.6428C89.1914 20.7634 88.8729 19.9397 88.3042 19.3386C87.4851 18.4705 86.2224 18.1589 84.9711 18.1589C83.7198 18.1589 82.4572 18.4705 81.6381 19.3386C80.819 20.2068 80.5232 21.5315 80.5232 22.845C80.5232 24.1582 80.819 25.4939 81.6381 26.3511C82.4572 27.208 83.7198 27.531 84.9711 27.531C86.2224 27.531 87.4851 27.2192 88.3042 26.3511C88.418 26.2285 88.5203 26.095 88.6227 25.9614L91.9899 25.9614C91.6941 27.0078 91.2277 27.9539 90.5225 28.6885C89.1573 30.1243 87.0529 30.6475 84.9711 30.6475C82.8894 30.6475 80.7849 30.1355 79.4198 28.6885C78.0547 27.2415 77.5542 25.0376 77.5542 22.845C77.5542 20.6521 78.0433 18.437 79.4198 17.0012C80.7963 15.5654 82.8894 15.0422 84.9711 15.0422C87.0529 15.0422 89.1573 15.5542 90.5225 17.0012C91.8988 18.4482 92.3881 20.6521 92.3881 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M109.52 22.845L109.52 24.0581L100.431 24.0581L100.431 21.6428L106.46 21.6428C106.323 20.7634 106.005 19.9397 105.436 19.3386C104.617 18.4705 103.354 18.1589 102.103 18.1589C100.852 18.1589 99.5889 18.4705 98.7698 19.3386C97.9507 20.2068 97.6549 21.5315 97.6549 22.845C97.6549 24.1582 97.9507 25.4939 98.7698 26.3511C99.5889 27.208 100.852 27.531 102.103 27.531C103.354 27.531 104.617 27.2192 105.436 26.3511C105.55 26.2285 105.652 26.095 105.754 25.9614L109.122 25.9614C108.826 27.0078 108.359 27.9539 107.654 28.6885C106.289 30.1243 104.185 30.6475 102.103 30.6475C100.021 30.6475 97.9166 30.1355 96.5515 28.6885C95.1864 27.2415 94.6859 25.0376 94.6859 22.845C94.6859 20.6521 95.175 18.437 96.5515 17.0012C97.928 15.5654 100.021 15.0422 102.103 15.0422C104.185 15.0422 106.289 15.5542 107.654 17.0012C109.031 18.4482 109.52 20.6521 109.52 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M136.355 30.6475C138.437 30.6475 140.541 30.3469 141.906 29.49C143.271 28.6328 143.772 27.3306 143.772 26.0393C143.772 24.7483 143.282 23.4348 141.906 22.5889C140.541 21.7429 138.437 21.4312 136.355 21.4312C135.467 21.4312 134.648 21.3088 134.068 20.9861C133.488 20.6521 133.272 20.1511 133.272 19.6504C133.272 19.1494 133.477 18.6375 134.068 18.3147C134.648 17.9807 135.547 17.8694 136.434 17.8694C137.322 17.8694 138.22 17.9919 138.801 18.3147C139.381 18.6487 139.597 19.1494 139.597 19.6504L143.066 19.6504C143.066 18.3591 142.623 17.0457 141.383 16.2C140.143 15.354 138.243 15.0422 136.355 15.0422C134.466 15.0422 132.567 15.3428 131.327 16.2C130.087 17.0569 129.643 18.3591 129.643 19.6504C129.643 20.9414 130.087 22.2549 131.327 23.1008C132.567 23.9468 134.466 24.2585 136.355 24.2585C137.333 24.2585 138.414 24.3809 139.062 24.7036C139.711 25.0266 139.938 25.5386 139.938 26.0393C139.938 26.5403 139.711 27.0522 139.062 27.375C138.414 27.6978 137.424 27.8203 136.446 27.8203C135.467 27.8203 134.466 27.6978 133.829 27.375C133.192 27.0522 132.953 26.5403 132.953 26.0393L128.949 26.0393C128.949 27.3306 129.438 28.644 130.815 29.49C132.191 30.3359 134.273 30.6475 136.355 30.6475Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M160.903 22.845L160.903 24.0581L151.814 24.0581L151.814 21.6428L157.843 21.6428C157.707 20.7634 157.388 19.9397 156.82 19.3386C156 18.4705 154.738 18.1589 153.486 18.1589C152.235 18.1589 150.972 18.4705 150.153 19.3386C149.334 20.2068 149.039 21.5315 149.039 22.845C149.039 24.1582 149.334 25.4939 150.153 26.3511C150.972 27.208 152.235 27.531 153.486 27.531C154.738 27.531 156 27.2192 156.82 26.3511C156.933 26.2285 157.036 26.095 157.138 25.9614L160.505 25.9614C160.209 27.0078 159.743 27.9539 159.038 28.6885C157.673 30.1243 155.568 30.6475 153.486 30.6475C151.405 30.6475 149.3 30.1355 147.935 28.6885C146.57 27.2415 146.07 25.0376 146.07 22.845C146.07 20.6521 146.559 18.437 147.935 17.0012C149.312 15.5654 151.405 15.0422 153.486 15.0422C155.568 15.0422 157.673 15.5542 159.038 17.0012C160.414 18.4482 160.903 20.6521 160.903 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
<path id="path" d="M178.035 22.845L178.035 24.0581L168.946 24.0581L168.946 21.6428L174.975 21.6428C174.839 20.7634 174.52 19.9397 173.951 19.3386C173.132 18.4705 171.87 18.1589 170.618 18.1589C169.367 18.1589 168.104 18.4705 167.285 19.3386C166.466 20.2068 166.17 21.5315 166.17 22.845C166.17 24.1582 166.466 25.4939 167.285 26.3511C168.104 27.208 169.367 27.531 170.618 27.531C171.87 27.531 173.132 27.2192 173.951 26.3511C174.065 26.2285 174.167 26.095 174.27 25.9614L177.637 25.9614C177.341 27.0078 176.875 27.9539 176.17 28.6885C174.804 30.1243 172.7 30.6475 170.618 30.6475C168.536 30.6475 166.432 30.1355 165.067 28.6885C163.702 27.2415 163.201 25.0376 163.201 22.845C163.201 20.6521 163.69 18.437 165.067 17.0012C166.443 15.5654 168.536 15.0422 170.618 15.0422C172.7 15.0422 174.804 15.5542 176.17 17.0012C177.546 18.4482 178.035 20.6521 178.035 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
<rect id="rect" x="180.321533" y="10.022217" width="3.412687" height="20.625223" fill="#4D6BFE"/>
<path id="polygon" d="M189.559 22.3772L195.155 30.6475L190.935 30.6475L185.338 22.3772L190.935 15.7322L195.155 15.7322L189.559 22.3772Z" fill-rule="nonzero" fill="#4D6BFE"/>
</g>
<path id="path" d="M55.6128 3.47119C55.0175 3.17944 54.7611 3.73535 54.413 4.01782C54.2939 4.10889 54.1932 4.22729 54.0924 4.33667C53.2223 5.26587 52.2057 5.87646 50.8776 5.80347C48.9359 5.69409 47.2781 6.30469 45.8126 7.78979C45.5012 5.9585 44.4663 4.86499 42.8909 4.16357C42.0667 3.79907 41.2332 3.43457 40.6561 2.64185C40.2532 2.07715 40.1432 1.44849 39.9418 0.828857C39.8135 0.455322 39.6853 0.0725098 39.2548 0.00878906C38.7877 -0.0639648 38.6045 0.327637 38.4213 0.655762C37.6886 1.99512 37.4047 3.47119 37.4321 4.96533C37.4962 8.32739 38.9159 11.0059 41.7369 12.9102C42.0575 13.1289 42.1399 13.3474 42.0392 13.6665C41.8468 14.3225 41.6178 14.9602 41.4164 15.6162C41.2881 16.0354 41.0957 16.1265 40.647 15.9441C39.0991 15.2974 37.7618 14.3406 36.5803 13.1836C34.5745 11.2429 32.761 9.10181 30.4988 7.42529C29.9675 7.03345 29.4363 6.66919 28.8867 6.32275C26.5786 4.08154 29.189 2.24097 29.7935 2.02246C30.4254 1.79468 30.0133 1.01099 27.9708 1.02026C25.9283 1.0293 24.0599 1.71265 21.6786 2.62378C21.3306 2.7605 20.9641 2.8606 20.5886 2.94263C18.4271 2.53271 16.1831 2.44141 13.8384 2.70581C9.42371 3.19775 5.89758 5.28418 3.30554 8.84668C0.191406 13.1289 -0.54126 17.9941 0.356323 23.0691C1.29968 28.4172 4.02905 32.8452 8.22388 36.3076C12.5745 39.8972 17.5845 41.6558 23.2997 41.3186C26.771 41.1182 30.6361 40.6536 34.9958 36.9636C36.0948 37.5103 37.2489 37.7288 39.1632 37.8928C40.6378 38.0295 42.0575 37.8201 43.1565 37.5923C44.8784 37.2278 44.7594 35.6333 44.1366 35.3418C39.09 32.9912 40.1981 33.9478 39.1907 33.1733C41.7552 30.1394 45.6204 26.9868 47.1316 16.7732C47.2506 15.9624 47.1499 15.4521 47.1316 14.7961C47.1224 14.3953 47.214 14.2405 47.672 14.1948C48.9359 14.0491 50.1632 13.7029 51.2898 13.0833C54.5596 11.2976 55.8784 8.36377 56.1898 4.84692C56.2357 4.30933 56.1807 3.75342 55.6128 3.47119ZM27.119 35.123C22.2281 31.2783 19.856 30.0117 18.8759 30.0664C17.96 30.1211 18.1249 31.1689 18.3263 31.8523C18.537 32.5264 18.8118 32.9912 19.1964 33.5833C19.462 33.9751 19.6453 34.5581 18.9309 34.9956C17.3555 35.9705 14.6169 34.6675 14.4886 34.6038C11.3014 32.7268 8.63611 30.2485 6.75842 26.8594C4.94495 23.5974 3.89172 20.0989 3.71765 16.3633C3.67188 15.4614 3.9375 15.1423 4.83508 14.9785C6.0166 14.7598 7.23474 14.7141 8.41626 14.8872C13.408 15.6162 17.6577 17.8484 21.2206 21.3835C23.2539 23.397 24.7926 25.8025 26.3772 28.1531C28.0624 30.6494 29.8759 33.0276 32.184 34.9773C32.9991 35.6606 33.6494 36.1799 34.2722 36.5627C32.3947 36.7722 29.2622 36.8179 27.119 35.123ZM29.4637 20.0442C29.4637 19.6433 29.7843 19.3245 30.1874 19.3245C30.2789 19.3245 30.3613 19.3425 30.4346 19.3699C30.5354 19.4065 30.627 19.4612 30.7002 19.543C30.8285 19.6707 30.9017 19.8528 30.9017 20.0442C30.9017 20.4451 30.5812 20.7639 30.1782 20.7639C29.7751 20.7639 29.4637 20.4451 29.4637 20.0442ZM36.7452 23.7798C36.2781 23.9712 35.811 24.135 35.3622 24.1533C34.6661 24.1897 33.9059 23.9072 33.4938 23.561C32.8527 23.0234 32.3947 22.7229 32.2023 21.7844C32.1199 21.3835 32.1656 20.7639 32.239 20.4087C32.4038 19.6433 32.2206 19.1514 31.6803 18.7048C31.2406 18.3403 30.6819 18.2402 30.0682 18.2402C29.8392 18.2402 29.6287 18.1399 29.4729 18.0579C29.2164 17.9304 29.0059 17.6116 29.2073 17.2197C29.2714 17.0923 29.5829 16.7825 29.6561 16.7278C30.4896 16.2539 31.4513 16.4089 32.3397 16.7642C33.1641 17.1013 33.7869 17.7209 34.6844 18.5955C35.6003 19.6523 35.7651 19.9441 36.2872 20.7366C36.6995 21.3562 37.075 21.9939 37.3314 22.7229C37.4871 23.1785 37.2856 23.552 36.7452 23.7798Z" fill-rule="nonzero" fill="#4D6BFE"/>
</svg>
BIN
assets/show1.jpg
Normal file
BIN
assets/show2.jpg
Normal file
BIN
assets/show3.jpg
Normal file
BIN
assets/show4.jpg
Normal file
9
requirements.txt
Normal file
@ -0,0 +1,9 @@
transformers==4.46.3
tokenizers==0.20.3
PyMuPDF
img2pdf
einops
easydict
addict
Pillow
numpy