KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from .helpers import to_2tuple
def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    The leading ``num_extra_tokens`` tokens are kept untouched; the remaining
    tokens are reshaped to a ``(1, C, src_h, src_w)`` map, interpolated to
    ``dst_shape`` and flattened back.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of downsampled new training
            image, in format (H, W).
        mode (str): Algorithm passed to ``F.interpolate``. Because the
            weights are resized as a 2-D spatial map, use a mode that
            accepts 4-D input, e.g. 'nearest', 'bilinear' or 'bicubic'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C]. When
        ``src_shape`` equals ``dst_shape`` the input is returned as is.
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]

    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    # `F.interpolate` rejects an explicit `align_corners` for
    # non-interpolating modes such as 'nearest', so only pass it when the
    # chosen mode actually supports the option.
    interpolating_modes = ('linear', 'bilinear', 'bicubic', 'trilinear')
    align_corners = False if mode in interpolating_modes else None

    # The cubic interpolate algorithm only accepts float32
    dst_weight = F.interpolate(
        src_weight.float(),
        size=dst_shape,
        align_corners=align_corners,
        mode=mode)
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
    # Cast back so that e.g. half-precision checkpoints keep their dtype.
    dst_weight = dst_weight.to(src_weight.dtype)

    return torch.cat((extra_tokens, dst_weight), dim=1)
def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
    """Resize relative position bias table.

    The source table is treated as a ``src_shape x src_shape`` grid per head
    whose sample coordinates grow geometrically away from the centre; it is
    then cubic-spline interpolated onto an integer grid spanning
    ``[-dst_shape // 2, dst_shape // 2]`` on both axes.

    Args:
        src_shape (int): Side length of the (square) source bias grid, i.e.
            ``table`` has ``src_shape * src_shape`` rows.
        dst_shape (int): Side length of the (square) target bias grid.
        table (tensor): The relative position bias of the pretrained model,
            indexed as ``table[:, head]``.
        num_head (int): Number of attention heads.

    Returns:
        torch.Tensor: The resized relative position bias table (float32) on
        the same device as ``table``, with one column per head and
        ``(2 * (dst_shape // 2) + 1) ** 2`` rows.
    """
    from scipy import interpolate

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    # Bisect for the ratio ``q`` whose geometric series over ``src_shape // 2``
    # terms reaches ``dst_shape // 2``.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_shape // 2)
        if gp > dst_shape // 2:
            right = q
        else:
            left = q

    # Positive-side source coordinates: 1, 1 + q, 1 + q + q^2, ...
    dis = []
    cur = 1
    for i in range(src_shape // 2):
        dis.append(cur)
        cur += q**(i + 1)

    r_ids = [-_ for _ in reversed(dis)]
    x = np.asarray(r_ids + [0] + dis, dtype=np.float64)
    y = x.copy()

    t = dst_shape // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    all_rel_pos_bias = []
    for i in range(num_head):
        # `.numpy()` needs a CPU tensor that does not track gradients.
        z = table[:, i].detach().cpu().reshape(src_shape,
                                               src_shape).float().numpy()
        # `interp2d` was removed from SciPy; on a regular grid its cubic mode
        # is reproduced by `RectBivariateSpline` with both z and the result
        # transposed (interp2d indexed z as [y, x]).
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
        resized = np.ascontiguousarray(spline(dx, dy).T)
        all_rel_pos_bias.append(
            torch.tensor(resized,
                         dtype=torch.float32).view(-1, 1).to(table.device))
    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return new_rel_pos_bias
class PatchEmbed(BaseModule):
    """Split a fixed-size image into patches and embed them with a conv.

    The projection is a conv layer built from ``conv_cfg``; unless overridden
    it is a ``Conv2d`` with ``kernel_size=16`` and ``stride=16``.

    Args:
        img_size (int | tuple): The size of input image. Default: 224
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None
        conv_cfg (dict, optional): The config dict for conv layers; its keys
            override the default conv settings. Default: None
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=768,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
                      "It's more general and supports dynamic input shape")

        # Normalise `img_size` to an (H, W) pair.
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.embed_dims = embed_dims

        # Start from the default conv settings and let the user override them.
        projection_cfg = dict(
            type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
        projection_cfg.update(conv_cfg or dict())
        self.projection = build_conv_layer(projection_cfg, in_channels,
                                           embed_dims)

        def _num_patches_along(axis):
            # Standard conv output-length formula for one spatial axis.
            proj = self.projection
            covered = proj.dilation[axis] * (proj.kernel_size[axis] - 1)
            span = self.img_size[axis] + 2 * proj.padding[axis] - covered - 1
            return span // proj.stride[axis] + 1

        # How many patches the input image is split into along H and W.
        h_out, w_out = [_num_patches_along(axis) for axis in range(2)]
        self.patches_resolution = (h_out, w_out)
        self.num_patches = h_out * w_out

        if norm_cfg is None:
            self.norm = None
        else:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x):
        """Embed a batch of images whose size matches ``self.img_size``.

        Args:
            x (Tensor): 4-D input; its last two dims must equal
                ``self.img_size``.

        Returns:
            Tensor: The projected patches, flattened over the spatial dims
            and transposed so tokens come before channels, optionally
            normalized.
        """
        _, _, height, width = x.shape
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # (B, D, H', W') -> (B, D, N) -> (B, N, D)
        tokens = self.projection(x)
        tokens = tokens.flatten(2)
        tokens = tokens.transpose(1, 2)

        if self.norm is None:
            return tokens
        return self.norm(tokens)
# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten,
    project to embedding dim.

    Args:
        backbone (nn.Module): CNN backbone
        img_size (int | tuple): The size of input image. Default: 224
        feature_size (int | tuple, optional): Size of feature map extracted by
            CNN backbone. When None, it is measured by running the backbone
            once on a zero image of ``img_size``. Default: None
        in_channels (int): The num of input channels; only used to build the
            zero image for the measurement above. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_cfg (dict, optional): The config dict for conv layers; its keys
            override the default 1x1 ``Conv2d`` projection settings.
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None,
                 init_cfg=None):
        super(HybridEmbed, self).__init__(init_cfg)
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            # FIXME this is hacky, but most reliable way of
            #  determining the exact dim of the output feature
            #  map for all networks, the feature metadata has
            #  reliable channel and stride info, but using
            #  stride to calc feature dim requires info about padding of
            #  each stage that isn't captured.
            training = backbone.training

            # Build the probe on the backbone's own device/dtype so that a
            # backbone already moved to GPU or cast to another float type
            # does not fail with a device/dtype mismatch.
            probe_kwargs = dict()
            ref_param = next(backbone.parameters(), None)
            if ref_param is not None:
                probe_kwargs['device'] = ref_param.device
                if ref_param.is_floating_point():
                    probe_kwargs['dtype'] = ref_param.dtype

            if training:
                backbone.eval()
            try:
                with torch.no_grad():
                    o = self.backbone(
                        torch.zeros(1, in_channels, img_size[0], img_size[1],
                                    **probe_kwargs))
            finally:
                # Restore the original mode even if the probe run raised.
                backbone.train(training)
            if isinstance(o, (list, tuple)):
                # last feature if backbone outputs list/tuple of features
                o = o[-1]
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use conv layer to embed
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)

    def forward(self, x):
        """Run the backbone and project its (last) feature map to tokens.

        Args:
            x (Tensor): Input forwarded to ``self.backbone`` unchanged.

        Returns:
            Tensor: The projected feature map, flattened over the spatial
            dims and transposed so tokens come before channels.
        """
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
class PatchMerging(BaseModule):
    """Merge patch feature map.

    Modified from mmcv, and this module supports specifying whether to use
    post-norm.

    Neighbouring patches are grouped with :class:`torch.nn.Unfold` and the
    grouped features are reduced by a linear layer (as used in Swin
    Transformer). A norm layer is applied either before the reduction
    (default) or after it (``use_post_norm=True``). The unfold-based
    implementation is about 25% faster than the original one, but
    pretrained models need to be modified for compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Defaults to None, which means to be set as
            ``kernel_size``.
        padding (int | tuple | string ): The padding length of the unfold
            layer. When it is a string, it selects the mode of adaptive
            padding ("same" or "corner") applied before unfolding, and the
            unfold layer itself uses no padding. Defaults to "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        use_post_norm (bool): Whether to use post normalization here.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 use_post_norm=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_post_norm = use_post_norm

        # A falsy stride falls back to the (not yet normalised) kernel size.
        stride = to_2tuple(stride or kernel_size)
        kernel_size = to_2tuple(kernel_size)
        dilation = to_2tuple(dilation)

        uses_adaptive_padding = isinstance(padding, str)
        if uses_adaptive_padding:
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
        else:
            self.adaptive_padding = None
        # With adaptive padding the unfold layer itself must not pad.
        padding = to_2tuple(0 if uses_adaptive_padding else padding)

        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        # Every merged token concatenates one kernel window of channels.
        sample_dim = kernel_size[0] * kernel_size[1] * in_channels
        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

        if norm_cfg is None:
            self.norm = None
        else:
            # Post-norm sees the reduced channels, pre-norm the grouped ones.
            norm_dims = out_channels if self.use_post_norm else sample_dim
            self.norm = build_norm_layer(norm_cfg, norm_dims)[1]

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (Merged_H, Merged_W).
        """
        batch, num_tokens, channels = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size is `Sequence` but get {input_size}'

        height, width = input_size
        assert num_tokens == height * width, 'input feature has wrong size'

        # (B, H*W, C) -> (B, C, H, W)
        feat = x.view(batch, height, width, channels).permute([0, 3, 1, 2])

        if self.adaptive_padding:
            feat = self.adaptive_padding(feat)
            height, width = feat.shape[-2:]

        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility
        # if kernel_size=2 and stride=2, feat has shape (B, 4*C, H/2*W/2)
        feat = self.sampler(feat)

        unfold = self.sampler
        merged = []
        for axis, length in enumerate((height, width)):
            # Sliding-window output length along this axis.
            window = unfold.dilation[axis] * (unfold.kernel_size[axis] - 1)
            span = length + 2 * unfold.padding[axis] - window - 1
            merged.append(span // unfold.stride[axis] + 1)
        output_size = tuple(merged)

        feat = feat.transpose(1, 2)  # B, H/2*W/2, 4*C

        if self.use_post_norm:
            # reduce first, normalise afterwards
            feat = self.reduction(feat)
            if self.norm:
                feat = self.norm(feat)
        else:
            if self.norm:
                feat = self.norm(feat)
            feat = self.reduction(feat)

        return feat, output_size