# Code-viewer metadata from the original listing: 175 lines, 7.1 KiB, Python.
import collections.abc
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class MlpProjector(nn.Module):
    """Project vision features to the language-model embedding width ``cfg.n_embed``.

    The layer stack is chosen by ``cfg.projector_type``:

    * ``identity`` -- no-op.
    * ``linear`` -- a single ``Linear(input_dim, n_embed)``.
    * ``mlp_gelu`` -- ``Linear(input_dim, n_embed)`` followed by ``depth - 1``
      ``[GELU, Linear(n_embed, n_embed)]`` blocks.
    * ``downsample_mlp_gelu`` / ``normlayer_downsample_mlp_gelu`` -- tokens are
      first merged in ``downsample_ratio x downsample_ratio`` windows, then fed
      to an MLP with hidden width ``n_embed * mlp_ratio``. The ``normlayer``
      variant starts with a ``LayerNorm``. Both always contain at least two
      Linear layers.
    * ``low_high_hybrid_split_mlp_gelu`` -- ``x[0]`` and ``x[1]`` are projected
      separately to ``n_embed // 2`` each, concatenated, then passed through
      ``depth - 1`` GELU/Linear blocks.
    * ``hybrid_split_feature_mlp_gelu`` -- the last dimension is split at
      ``input_dim[0]``. The two parts are projected to widths that sum to
      ``n_embed`` (split by ``channel_div``), concatenated, then passed through
      ``depth - 1`` GELU/Linear blocks.
    * ``low_high_split_mlp_gelu`` -- ``x[0]`` and ``x[1]`` each go through their
      own ``depth - 1`` ``[GELU, Linear(n_embed // 2, n_embed // 2)]`` blocks
      and are concatenated.

    ``cfg`` is read both by attribute (``cfg.input_dim``) and via
    ``cfg.get(key, default)``. NOTE(review): presumably an attribute-style dict
    defined elsewhere -- confirm.
    """

    def __init__(self, cfg):
        """Build the projection layers described by ``cfg``.

        Args:
            cfg: projector config. Always read: ``projector_type``. Depending on
                the type, also ``input_dim``, ``n_embed`` and
                ``downsample_ratio``. ``input_dim`` is indexed as a pair for
                ``hybrid_split_feature_mlp_gelu``. Optional keys (with
                defaults): ``depth`` (1), ``mlp_ratio`` (1),
                ``channel_div`` (0.5), ``token_pooling`` (False) and
                ``conv_fusion_high_low_features`` (False).

        Raises:
            ValueError: if ``cfg.projector_type`` is not one of the supported types.
        """
        super().__init__()

        self.cfg = cfg
        projector_type = cfg.projector_type
        mlp_depth = cfg.get("depth", 1)

        if projector_type == "identity":
            modules = nn.Identity()

        elif projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif projector_type == "mlp_gelu":
            modules = nn.Sequential(
                nn.Linear(cfg.input_dim, cfg.n_embed),
                *self._gelu_linear_blocks(cfg.n_embed, mlp_depth - 1),
            )

        elif projector_type in ("normlayer_downsample_mlp_gelu", "downsample_mlp_gelu"):
            mlp_ratio = cfg.get("mlp_ratio", 1)
            # Each downsampled token concatenates a ratio x ratio window of input tokens.
            in_dim = cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio
            hidden_dim = cfg.n_embed * mlp_ratio
            stack = []
            if projector_type == "normlayer_downsample_mlp_gelu":
                stack.append(nn.LayerNorm(in_dim))
            stack.append(nn.Linear(in_dim, hidden_dim))
            # ``depth`` counts the first and last Linear, leaving depth - 2 hidden blocks.
            # There are none when depth <= 2; the output layer below is always added.
            stack.extend(self._gelu_linear_blocks(hidden_dim, mlp_depth - 2))
            stack.append(nn.GELU())
            stack.append(nn.Linear(hidden_dim, cfg.n_embed))
            modules = nn.Sequential(*stack)

        elif projector_type == "low_high_hybrid_split_mlp_gelu":
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            modules = nn.Sequential(*self._gelu_linear_blocks(cfg.n_embed, mlp_depth - 1))

        elif projector_type == "hybrid_split_feature_mlp_gelu":
            channel_div = cfg.get("channel_div", 0.5)
            high_dim = int(cfg.n_embed * channel_div)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], high_dim)
            # The low branch takes the remainder so the concatenation is exactly n_embed wide.
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - high_dim)
            modules = nn.Sequential(*self._gelu_linear_blocks(cfg.n_embed, mlp_depth - 1))

        elif projector_type == "low_high_split_mlp_gelu":
            modules = nn.Sequential(*self._gelu_linear_blocks(cfg.n_embed // 2, mlp_depth - 1))
            # high_layers is a new container around the *same* submodules as
            # ``modules`` (which also becomes self.layers below). low_layers is
            # an independent deep copy with its own parameters.
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            # Projects each concatenated 2x2 token window (4 * input_dim) back to input_dim.
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)

        self.layers = modules

    @staticmethod
    def _gelu_linear_blocks(dim, count):
        """Return ``count`` repetitions of ``[GELU(), Linear(dim, dim)]`` as one flat list.

        Returns an empty list when ``count <= 0``.
        """
        blocks = []
        for _ in range(count):
            blocks.append(nn.GELU())
            blocks.append(nn.Linear(dim, dim))
        return blocks

    def forward(self, x):
        """Apply the configured projection to ``x``.

        Stages, in order (each only when configured):

        1. ``token_pooling``: merge each 2x2 window of the square token grid.
        2. ``conv_fusion_high_low_features``: ``fusion_layer(x[:, 0]) + x[:, 1]``.
        3. Type-specific input handling. The split types project or transform
           ``x[0]``/``x[1]`` or slices of the last dimension; the downsample
           types merge token windows. ``low_high_split_mlp_gelu`` returns at
           this stage.
        4. ``self.layers``.

        Args:
            x: input tensor. The pooling and downsampling stages unpack it as
                ``(batch, tokens, channels)`` and reshape ``tokens`` into a
                square grid, so ``tokens`` must be a perfect square there.

        Returns:
            The projected tensor.
        """
        cfg = self.cfg
        projector_type = cfg.projector_type

        if cfg.get("token_pooling", False):
            x = self._pool_2x2_tokens(x)

        if cfg.get("conv_fusion_high_low_features", False):
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if projector_type == "low_high_hybrid_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            x = torch.concat([self.high_up_proj(high_x), self.low_up_proj(low_x)], dim=-1)

        if projector_type == "hybrid_split_feature_mlp_gelu":
            split = cfg.input_dim[0]
            high_x = x[..., :split]
            low_x = x[..., split:]
            x = torch.concat([self.high_up_proj(high_x), self.low_up_proj(low_x)], dim=-1)

        if projector_type == "low_high_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            # Early return: self.layers shares its submodules with high_layers and
            # must not be applied again for this projector type.
            return torch.concat([self.high_layers(high_x), self.low_layers(low_x)], dim=-1)

        if projector_type in ("downsample_mlp_gelu", "normlayer_downsample_mlp_gelu"):
            x = self._downsample_tokens(x)

        return self.layers(x)

    def _pool_2x2_tokens(self, x):
        """Concatenate each non-overlapping 2x2 token window and project it with ``token_pooling_layer``.

        ``x`` is unpacked as ``(batch, tokens, channels)``. The tokens are viewed
        as a square ``side x side`` grid with ``side = int(sqrt(tokens))``. If
        ``side`` is odd, ``Tensor.unfold`` drops the last row/column of tokens.
        """
        batch_size, num_tokens, channels = x.shape
        side = int(num_tokens ** 0.5)
        grid = x.view(batch_size, side, side, channels).permute(0, 3, 1, 2)  # (B, C, side, side)
        # (B, C, side // 2, side // 2, 2, 2): non-overlapping 2x2 windows over both spatial dims.
        patches = grid.unfold(2, 2, 2).unfold(3, 2, 2)
        batch_size, channels, h_patches, w_patches, _, _ = patches.size()
        # Flatten the window grid; each window's 4 tokens become the trailing dim.
        patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
        # Concatenate the 4 tokens of each window along the channel dimension -> (B, windows, 4 * C).
        patches = patches.permute(0, 2, 1, 3).contiguous()
        patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
        return self.token_pooling_layer(patches)

    def _downsample_tokens(self, x):
        """Merge each ``r x r`` token window (``r = cfg.downsample_ratio``) into one token.

        ``x`` is unpacked as ``(batch, tokens, channels)`` and the tokens are
        treated as a square grid. The grid is zero-padded on the bottom/right up
        to a multiple of ``r``. Returns ``(batch, windows, channels * r * r)``.
        """
        ratio = self.cfg.downsample_ratio
        bs, hw, input_dim = x.shape
        h = w = int(hw ** 0.5)

        # Smallest non-negative pad that makes the grid side divisible by ratio.
        pad = ratio - h % ratio if h % ratio else 0
        x = x.reshape(bs, h, w, input_dim)
        if pad > 0:
            # F.pad pads from the last dim backwards: channels by (0, 0), then w and h by (0, pad).
            x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

        # Concatenate every ratio x ratio window into a single token.
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
        x = F.unfold(x, kernel_size=ratio, stride=ratio, padding=0)  # (B, C * r * r, windows)
        return x.permute(0, 2, 1)

    @staticmethod
    def get_flops_per_sample(cfg):
        """Estimate the projector's matmul FLOPs, multiplied by 3.

        The forward estimate counts 2 FLOPs per multiply-add. The factor of 3
        presumably covers forward plus a backward pass at ~2x forward; confirm
        against the caller. ``cfg.input_dim`` may be an int or any sequence of
        ints (list, tuple, ...), which is summed. Types not covered return 0.

        NOTE(review): the ``mlp_gelu`` estimate assumes one
        ``input_dim -> n_embed`` layer plus ``depth - 1``
        ``n_embed -> n_embed`` layers. It does not account for ``mlp_ratio``,
        for the output layer the downsample variants still build when
        ``depth == 1``, or for the half-width branches of the split types.
        Confirm whether that approximation is intended.
        """
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type:
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            input_dim = cfg.input_dim
            # Any sequence of per-branch widths (not only ``list``) is summed. A tuple
            # would otherwise be repeated by the multiplications below.
            if isinstance(input_dim, collections.abc.Sequence):
                input_dim = sum(input_dim)
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed

        else:
            fwd = 0

        return fwd * 3
|
|
|