# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule

from .helpers import to_2tuple

def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of downsampled new training
            image, in format (H, W).
        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
            'linear', 'bilinear', 'bicubic' and 'trilinear'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C].
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the ' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]
    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    # The cubic interpolation algorithm only accepts float32.
    dst_weight = F.interpolate(
        src_weight.float(), size=dst_shape, align_corners=False, mode=mode)
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
    dst_weight = dst_weight.to(src_weight.dtype)

    return torch.cat((extra_tokens, dst_weight), dim=1)
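

# Usage sketch (illustrative zero tensor, not pretrained weights): resize a
# ViT position embedding from a 14x14 patch grid to a 16x16 grid while
# keeping the leading cls token untouched.
#
# >>> pos_embed = torch.zeros(1, 1 + 14 * 14, 768)  # [1, L, C]
# >>> resized = resize_pos_embed(pos_embed, (14, 14), (16, 16),
# ...                            mode='bicubic', num_extra_tokens=1)
# >>> resized.shape
# torch.Size([1, 257, 768])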


def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
    """Resize relative position bias table.

    Args:
        src_shape (int): The side length of the relative position bias table
            of the pretrained model.
        dst_shape (int): The side length of the target relative position bias
            table.
        table (tensor): The relative position bias table of the pretrained
            model, with shape ``(src_shape ** 2, num_head)``.
        num_head (int): Number of attention heads.

    Returns:
        torch.Tensor: The resized relative position bias table.
    """
    from scipy import interpolate

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_shape // 2)
        if gp > dst_shape // 2:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(src_shape // 2):
        dis.append(cur)
        cur += q**(i + 1)
    r_ids = [-_ for _ in reversed(dis)]

    x = r_ids + [0] + dis
    y = r_ids + [0] + dis

    t = dst_shape // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    all_rel_pos_bias = []
    for i in range(num_head):
        z = table[:, i].view(src_shape, src_shape).float().numpy()
        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
        all_rel_pos_bias.append(
            torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                table.device))
    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return new_rel_pos_bias
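

# Usage sketch (random weights for illustration; requires a SciPy version
# that still provides `scipy.interpolate.interp2d`): interpolate a Swin-style
# relative position bias table from a 7x7 window (table side 2*7-1=13) to an
# 8x8 window (table side 2*8-1=15) for 4 attention heads.
#
# >>> table = torch.randn((2 * 7 - 1)**2, 4)  # (169, 4)
# >>> new_table = resize_relative_position_bias_table(13, 15, table, 4)
# >>> new_table.shape
# torch.Size([225, 4])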


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        img_size (int | tuple): The size of input image. Default: 224
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None
        conv_cfg (dict, optional): The config dict for conv layers.
            Default: None
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=768,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg)
        warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
                      "It's more general and supports dynamic input shape")

        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.embed_dims = embed_dims

        # Use conv layer to embed
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)

        # Calculate how many patches an input image is split into.
        h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] -
                         self.projection.dilation[i] *
                         (self.projection.kernel_size[i] - 1) - 1) //
                        self.projection.stride[i] + 1 for i in range(2)]
        self.patches_resolution = (h_out, w_out)
        self.num_patches = h_out * w_out
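        # e.g. with the default 224x224 input and a 16x16 kernel/stride:
        # h_out = w_out = (224 + 0 - 1 * (16 - 1) - 1) // 16 + 1 = 14,
        # so num_patches = 14 * 14 = 196.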

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # The output size is (B, N, D), where N = H * W / P / P and D is
        # the embedding dimension.
        x = self.projection(x).flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)
        return x
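

# Usage sketch (default settings; the zero tensor only illustrates shapes):
#
# >>> patch_embed = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
# >>> tokens = patch_embed(torch.zeros(1, 3, 224, 224))
# >>> tokens.shape
# torch.Size([1, 196, 768])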


# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten, project to embedding dim.

    Args:
        backbone (nn.Module): CNN backbone
        img_size (int | tuple): The size of input image. Default: 224
        feature_size (int | tuple, optional): Size of feature map extracted by
            CNN backbone. Default: None
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_cfg (dict, optional): The config dict for conv layers.
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None,
                 init_cfg=None):
        super(HybridEmbed, self).__init__(init_cfg)
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME: this is hacky, but it is the most reliable way of
                # determining the exact dim of the output feature map for all
                # networks. The feature metadata has reliable channel and
                # stride info, but using stride to calculate the feature dim
                # requires info about the padding of each stage that isn't
                # captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(
                    torch.zeros(1, in_channels, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    # last feature if backbone outputs list/tuple of features
                    o = o[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use conv layer to embed
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
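

# Usage sketch (a toy CNN stands in for a real backbone such as a ResNet;
# the single conv below only serves to produce a 7x7 feature map with 64
# channels):
#
# >>> cnn = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=32, padding=3))
# >>> hybrid_embed = HybridEmbed(cnn, img_size=224, embed_dims=768)
# >>> tokens = hybrid_embed(torch.zeros(1, 3, 224, 224))
# >>> tokens.shape
# torch.Size([1, 49, 768])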


class PatchMerging(BaseModule):
    """Merge patch feature map.

    Modified from mmcv, and this module supports specifying whether to use
    post-norm.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map (used in Swin Transformer). Our
    implementation uses :class:`torch.nn.Unfold` to merge patches, which is
    about 25% faster than the original implementation. However, we need to
    modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels. It should be fully
            covered by the filter and stride you specify.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Defaults to None, which means to be set as
            ``kernel_size``.
        padding (int | tuple | string): The padding length of embedding conv.
            When it is a string, it means the mode of adaptive padding;
            "same" and "corner" are supported now. Defaults to "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in the linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        use_post_norm (bool): Whether to use post normalization here.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 use_post_norm=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_post_norm = use_post_norm

        # fall back to kernel_size when stride is not given
        stride = stride or kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adaptive_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

        if norm_cfg is not None:
            # build pre or post norm layer based on different channels
            if self.use_post_norm:
                self.norm = build_norm_layer(norm_cfg, out_channels)[1]
            else:
                self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W). Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size to be a `Sequence`, but got {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'
        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W

        if self.adaptive_padding:
            x = self.adaptive_padding(x)
            H, W = x.shape[-2:]

        # Use nn.Unfold to merge patches. About 25% faster than the original
        # method, but the pretrained model needs to be modified for
        # compatibility. If kernel_size=2 and stride=2, x should have shape
        # (B, 4*C, H/2*W/2).
        x = self.sampler(x)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
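        # e.g. H = W = 56 with the default 2x2 kernel and stride (no padding,
        # dilation 1) gives out_h = out_w = (56 - 1 - 1) // 2 + 1 = 28.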
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C

        if self.use_post_norm:
            # use post-norm here
            x = self.reduction(x)
            x = self.norm(x) if self.norm else x
        else:
            x = self.norm(x) if self.norm else x
            x = self.reduction(x)

        return x, output_size
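

# Usage sketch (zero tensor only to show shapes): with the default 2x2 kernel
# and stride, each spatial dimension is halved and the linear reduction maps
# 4 * in_channels to out_channels.
#
# >>> merge = PatchMerging(in_channels=96, out_channels=192)
# >>> x = torch.zeros(1, 56 * 56, 96)
# >>> out, out_size = merge(x, (56, 56))
# >>> out.shape, out_size
# (torch.Size([1, 784, 192]), (28, 28))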