Files
DeepSeek-OCR/DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/clip_sdpa.py
Ucas-Haoranwei 093b215067 Initial commit
2025-10-20 10:48:36 +08:00

505 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
# FlashSelfAttention,
# NoopTransformerLayer,
# _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func
# from megatron.model.module import MegatronModule
# try:
# from einops import rearrange
# except ImportError:
# rearrange = None
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# try:
# # flash attention 2.x
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
# try:
# # flash attention 1.x
# from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# except ImportError:
# flash_attn_unpadded_func = None
# try:
# from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
# flash_attn_unpadded_relative_attention_bias_func = None
# try:
# from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
# mask_flash_attn_unpadded_func = None
class LayerNormfp32(torch.nn.LayerNorm):
    """LayerNorm that always performs the normalization in float32.

    The input is upcast to ``torch.float32`` before the parent LayerNorm runs
    and the result is cast back to the dtype the input arrived in, so callers
    feeding fp16 activations get fp16 back.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
def get_abs_pos(abs_pos, tgt_size):
    """Resize an absolute position-embedding table to a different grid size.

    Args:
        abs_pos: Position table. After ``squeeze(0)`` it is treated as
            ``(L, C)``: row 0 is the class-token embedding and the remaining
            ``L - 1`` rows are viewed as a square ``sqrt(L-1) x sqrt(L-1)``
            grid.
        tgt_size: Target token count. Only ``int(sqrt(tgt_size))`` is used
            (as the target grid side), so a count that includes the class
            token yields the same side as the bare patch count.

    Returns:
        ``abs_pos`` itself (untouched) when the source and target grid sides
        already match; otherwise a new tensor of shape
        ``(1, side * side + 1, C)`` in the dtype of ``abs_pos``, made of the
        original class-token row followed by the bicubically resized grid.
    """
    channels = abs_pos.size(-1)
    table = abs_pos.squeeze(0)
    src_side = int(math.sqrt(table.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    # Nothing to resample when both grids have the same side length.
    if src_side == tgt_side:
        return abs_pos

    cls_embed = table[:1]
    grid_embed = table[1:]
    out_dtype = abs_pos.dtype

    # (L-1, C) -> (1, C, side, side), the layout F.interpolate expects.
    grid = grid_embed.view(1, src_side, src_side, channels)
    grid = grid.permute(0, 3, 1, 2).contiguous()
    # Interpolate in float32, then return to the table's original dtype.
    grid = grid.to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(out_dtype)

    # Back to token-major layout: (1, C, side, side) -> (side*side, C).
    resized = resized.permute(0, 2, 3, 1)
    resized = resized.view(tgt_side * tgt_side, channels)

    merged = torch.cat([cls_embed, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, channels)
@torch.jit.script
def quick_gelu(x):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate
class CLIPVisionEmbeddings(nn.Module):
    """Class-token + patch embeddings with resizable absolute positions.

    Patch features are either computed from ``pixel_values`` by a strided
    convolution or supplied ready-made by the caller. A learned class token
    is prepended and a learned absolute position embedding, resized by
    ``get_abs_pos`` to the actual sequence length, is added.

    Args:
        hidden_size: Embedding width.
        image_size: Nominal input resolution; sets the size of the learned
            position table (``(image_size // patch_size) ** 2 + 1`` rows).
        patch_size: Kernel size and stride of the patch convolution.
        num_channels: Input channels of the patch convolution.
    """

    def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
        super().__init__()
        self.embed_dim = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size

        self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values, patch_embeds=None):
        """Build the token sequence fed to the transformer.

        Args:
            pixel_values: Input images. Always used for the batch size; also
                fed to the patch convolution when ``patch_embeds`` is None.
            patch_embeds: Optional precomputed patch features that replace
                the patch convolution. They are flattened from dim 2 onward
                and transposed, so a channel-first layout such as
                ``(batch, embed_dim, grid, grid)`` is expected.

        Returns:
            Tensor of shape ``(batch, 1 + num_tokens, embed_dim)``.
        """
        batch_size = pixel_values.shape[0]

        if patch_embeds is None:
            patch_embeds = self.patch_embedding(pixel_values)

        # (B, C, ...) -> (B, N, C)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # The learned table is interpolated when the incoming grid differs
        # from the grid the table was built for.
        position_table = self.position_embedding(self.position_ids)
        embeddings = embeddings + get_abs_pos(position_table, embeddings.size(1))
        return embeddings
class NoTPFeedForward(nn.Module):
    """Two-layer MLP with a quick-GELU activation in between.

    Args:
        cfg: Accepted for signature compatibility; not read by this class.
        dim: Input and output feature width.
        hidden_dim: Width of the intermediate projection.
    """

    def __init__(self, cfg, dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        hidden = quick_gelu(self.fc1(x))
        return self.fc2(hidden)
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
# class NoTPAttention(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# self.num_heads = cfg.num_attention_heads
# self.n_local_heads = cfg.num_attention_heads
# self.head_dim = cfg.hidden_size // cfg.num_attention_heads
# self.max_seq_len = cfg.seq_length
# self.use_flash_attention = cfg.use_flash_attn
# self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
# self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
# self.attn_drop = cfg.attention_dropout
# def forward(
# self,
# x: torch.Tensor,
# ):
# bsz, seqlen, _ = x.shape
# xqkv = self.qkv_proj(x)
# xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
# if self.use_flash_attention:
# output = flash_attn_qkvpacked_func(xqkv)
# output = output.view(bsz, seqlen, -1)
# else:
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# output = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
# output = self.out_proj(output)
# return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
class NoTPAttention(torch.nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Reads ``num_attention_heads``, ``hidden_size``, ``seq_length``,
    ``use_flash_attn`` and ``attention_dropout`` from ``cfg``. Depending on
    ``use_flash_attn`` the attention itself runs through
    ``flash_attn_qkvpacked_func`` or through PyTorch's
    ``scaled_dot_product_attention``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.n_local_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn

        self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
        self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)

        # Stored only; forward() does not apply dropout.
        self.attn_drop = cfg.attention_dropout

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Apply self-attention to ``x`` of shape (batch, seq, hidden)."""
        batch, tokens, _ = x.shape

        # (B, S, 3 * hidden) -> (B, S, 3, heads, head_dim)
        packed = self.qkv_proj(x).view(batch, tokens, 3, self.num_heads, self.head_dim)

        if self.use_flash_attention:
            attended = flash_attn_qkvpacked_func(packed).view(batch, tokens, -1)
        else:
            # Split q/k/v and move heads ahead of the sequence axis:
            # (B, S, heads, head_dim) -> (B, heads, S, head_dim).
            query, key, value = (part.transpose(1, 2) for part in packed.unbind(dim=2))
            attended = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None
            )
            # (B, heads, S, head_dim) -> (B, S, hidden)
            attended = attended.permute(0, 2, 1, 3).reshape(batch, tokens, -1)

        return self.out_proj(attended)
class NoTPTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        cfg: Config providing ``num_attention_heads``, ``hidden_size``,
            ``ffn_hidden_size`` and ``layernorm_epsilon`` (plus whatever
            ``NoTPAttention`` reads).
        layer_id: Identifier stored on the block; not used in the forward
            computation.
        multiple_of: Unused; kept so existing call sites stay valid.
    """

    def __init__(self, cfg, layer_id: int, multiple_of=256):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.self_attn = NoTPAttention(cfg)
        self.mlp = NoTPFeedForward(
            cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
        )
        self.layer_id = layer_id
        self.layer_norm1 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )
        self.layer_norm2 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )

    def forward(self, x: torch.Tensor):
        """Return ``h + mlp(norm2(h))`` where ``h = x + attn(norm1(x))``."""
        # Invoke the submodules through __call__ rather than .forward() so
        # that hooks registered on them are executed.
        attn_out = self.self_attn(self.layer_norm1(x))
        h = x + attn_out
        out = h + self.mlp(self.layer_norm2(h))
        return out
class NoTPTransformer(nn.Module):
    """Sequential stack of ``cfg.num_layers`` ``NoTPTransformerBlock`` layers.

    Blocks are created with 1-based layer ids. The activation-recompute
    (checkpointing) path that used to live in ``forward`` is disabled; every
    layer is simply applied in order.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.num_layers = cfg.num_layers
        self.layers = torch.nn.ModuleList(
            NoTPTransformerBlock(cfg, layer_number)
            for layer_number in range(1, self.num_layers + 1)
        )

    def forward(
        self,
        hidden_states,
    ):
        """Run ``hidden_states`` through every layer and return the result."""
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states
# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
class VitModel(nn.Module):
    """Vision transformer: embeddings -> pre-LayerNorm -> transformer stack.

    Args:
        cfg: Dict-like config (``.get`` and attribute access are both used).
            Reads ``hidden_size``, ``image_size``, ``patch_size``, and the
            optional keys ``fp32norm`` and ``pre_layernorm_epsilon``; the
            transformer reads further keys from the same object.
        freeze_embed: When True, disable gradients for all embedding
            parameters.
        freeze_pre_norm: When True, disable gradients for the pre-LayerNorm
            parameters.
    """

    def __init__(
        self,
        cfg,
        freeze_embed=False,
        freeze_pre_norm=False
    ) -> None:
        super().__init__()

        self.embeddings = CLIPVisionEmbeddings(
            hidden_size=cfg.hidden_size,
            image_size=cfg.image_size,
            patch_size=cfg.patch_size,
        )

        if freeze_embed:
            for param in self.embeddings.parameters():
                param.requires_grad = False

        self.transformer = NoTPTransformer(cfg=cfg)

        pre_norm_eps = cfg.get("pre_layernorm_epsilon", 1e-5)
        if cfg.get("fp32norm", False):
            # The module never defined `logger`, so this branch used to fail
            # with NameError. Import locally to keep the block self-contained.
            import logging

            logging.getLogger(__name__).info("Load fp32 layernorm for ViT.")
            # NOTE: the attribute spelling `pre_layrnorm` is kept as-is; it
            # is part of the parameter names in the state dict.
            self.pre_layrnorm = LayerNormfp32(
                cfg.hidden_size,
                eps=pre_norm_eps,
            )
        else:
            self.pre_layrnorm = torch.nn.LayerNorm(
                cfg.hidden_size,
                eps=pre_norm_eps,
            )

        if freeze_pre_norm:
            for param in self.pre_layrnorm.parameters():
                param.requires_grad = False

        # Tag every parameter with a `micro_dp` attribute; the consumer of
        # this flag is not visible in this file.
        for p in self.parameters():
            p.micro_dp = True

    def set_input_tensor(self, input_tensor):
        """Forward the (first) input tensor to the transformer.

        NOTE(review): ``NoTPTransformer`` as defined in this file has no
        ``set_input_tensor`` method, so this call raises ``AttributeError``
        unless the transformer is replaced by one that provides it.
        """
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.transformer.set_input_tensor(input_tensor[0])

    def __str__(self) -> str:
        return "open_clip"

    def forward(
        self,
        x,
        patch_embeds=None
    ):
        """Encode ``x`` and return the transformer output.

        Args:
            x: Input passed to the embedding layer as ``pixel_values``.
            patch_embeds: Optional precomputed patch features forwarded to
                the embedding layer; ``None`` lets the embedding layer
                compute them from ``x``.
        """
        x = self.embeddings(x, patch_embeds)
        hidden_states = self.pre_layrnorm(x)
        output = self.transformer(hidden_states)
        return output
# Default configuration used by build_clip_l(): 24 layers, width 1024,
# 16 heads, 224px images with 14px patches.
vit_model_cfg = adict(
    # Transformer geometry.
    num_layers=24,
    hidden_size=1024,
    num_heads=16,
    num_attention_heads=16,
    ffn_hidden_size=4096,
    seq_length=256,
    max_position_embeddings=256,
    # Attention backend: False selects torch scaled_dot_product_attention.
    use_flash_attn=False,
    understand_projector_stride=2,
    # Regularisation / normalisation.
    hidden_dropout=0.0,
    attention_dropout=0.0,
    no_persist_layer_norm=False,
    layernorm_epsilon=1e-5,
    pre_layernorm_epsilon=1e-5,
    # Patch embedding geometry.
    image_size=224,
    patch_size=14,
    recompute_list=[],
)
def build_clip_l():
    """Construct a ``VitModel`` from the module-level ``vit_model_cfg``.

    Neither the embeddings nor the pre-LayerNorm are frozen.
    """
    model = VitModel(
        cfg=vit_model_cfg,
        freeze_embed=False,
        freeze_pre_norm=False,
    )
    return model
if __name__ == '__main__':
    # Manual smoke test: run a SAM ViT-B encoder on a dummy batch, feed its
    # output to the CLIP tower as precomputed patch embeddings, and print the
    # resulting shapes.
    from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b

    sam_model = build_sam_vit_b()

    # Reuse the module-level configuration through build_clip_l() instead of
    # keeping a second hand-written copy of vit_model_cfg here; the copy
    # shadowed the module-level name and had already drifted from it.
    vision_model = build_clip_l()

    x = torch.zeros(2, 3, 1024, 1024)

    with torch.no_grad():
        patch_embed = sam_model(x)
        print(patch_embed.shape)
        y = vision_model(x, patch_embed)
        print(y.shape)
        # Drop the class token and add the SAM features token-wise.
        image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))
        print(image_feature.shape)