# Copyright (c) OpenMMLab. All rights reserved. import warnings from typing import Sequence import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn.bricks.transformer import AdaptivePadding from mmengine.model import BaseModule from .helpers import to_2tuple def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic', num_extra_tokens=1): """Resize pos_embed weights. Args: pos_embed (torch.Tensor): Position embedding weights with shape [1, L, C]. src_shape (tuple): The resolution of downsampled origin training image, in format (H, W). dst_shape (tuple): The resolution of downsampled new training image, in format (H, W). mode (str): Algorithm used for upsampling. Choose one from 'nearest', 'linear', 'bilinear', 'bicubic' and 'trilinear'. Defaults to 'bicubic'. num_extra_tokens (int): The number of extra tokens, such as cls_token. Defaults to 1. Returns: torch.Tensor: The resized pos_embed of shape [1, L_new, C] """ if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: return pos_embed assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' _, L, C = pos_embed.shape src_h, src_w = src_shape assert L == src_h * src_w + num_extra_tokens, \ f"The length of `pos_embed` ({L}) doesn't match the expected " \ f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ '`img_size` argument.' extra_tokens = pos_embed[:, :num_extra_tokens] src_weight = pos_embed[:, num_extra_tokens:] src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) # The cubic interpolate algorithm only accepts float32 dst_weight = F.interpolate( src_weight.float(), size=dst_shape, align_corners=False, mode=mode) dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) dst_weight = dst_weight.to(src_weight.dtype) return torch.cat((extra_tokens, dst_weight), dim=1) def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head): """Resize relative position bias table. Args: src_shape (int): The resolution of downsampled origin training image, in format (H, W). dst_shape (int): The resolution of downsampled new training image, in format (H, W). table (tensor): The relative position bias of the pretrained model. num_head (int): Number of attention heads. Returns: torch.Tensor: The resized relative position bias table. """ from scipy import interpolate def geometric_progression(a, r, n): return a * (1.0 - r**n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_shape // 2) if gp > dst_shape // 2: right = q else: left = q dis = [] cur = 1 for i in range(src_shape // 2): dis.append(cur) cur += q**(i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_shape // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) all_rel_pos_bias = [] for i in range(num_head): z = table[:, i].view(src_shape, src_shape).float().numpy() f_cubic = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(table.device)) new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) return new_rel_pos_bias class PatchEmbed(BaseModule): """Image to Patch Embedding. We use a conv layer to implement PatchEmbed. Args: img_size (int | tuple): The size of input image. Default: 224 in_channels (int): The num of input channels. Default: 3 embed_dims (int): The dimensions of embedding. Default: 768 norm_cfg (dict, optional): Config dict for normalization layer. Default: None conv_cfg (dict, optional): The config dict for conv layers. Default: None init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. Default: None """ def __init__(self, img_size=224, in_channels=3, embed_dims=768, norm_cfg=None, conv_cfg=None, init_cfg=None): super(PatchEmbed, self).__init__(init_cfg) warnings.warn('The `PatchEmbed` in mmcls will be deprecated. ' 'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. ' "It's more general and supports dynamic input shape") if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) assert len(img_size) == 2, \ f'The size of image should have length 1 or 2, ' \ f'but got {len(img_size)}' self.img_size = img_size self.embed_dims = embed_dims # Use conv layer to embed conv_cfg = conv_cfg or dict() _conv_cfg = dict( type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1) _conv_cfg.update(conv_cfg) self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims) # Calculate how many patches a input image is splited to. h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] - self.projection.dilation[i] * (self.projection.kernel_size[i] - 1) - 1) // self.projection.stride[i] + 1 for i in range(2)] self.patches_resolution = (h_out, w_out) self.num_patches = h_out * w_out if norm_cfg is not None: self.norm = build_norm_layer(norm_cfg, embed_dims)[1] else: self.norm = None def forward(self, x): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't " \ f'match model ({self.img_size[0]}*{self.img_size[1]}).' # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim x = self.projection(x).flatten(2).transpose(1, 2) if self.norm is not None: x = self.norm(x) return x # Modified from pytorch-image-models class HybridEmbed(BaseModule): """CNN Feature Map Embedding. Extract feature map from CNN, flatten, project to embedding dim. Args: backbone (nn.Module): CNN backbone img_size (int | tuple): The size of input image. Default: 224 feature_size (int | tuple, optional): Size of feature map extracted by CNN backbone. Default: None in_channels (int): The num of input channels. Default: 3 embed_dims (int): The dimensions of embedding. Default: 768 conv_cfg (dict, optional): The config dict for conv layers. Default: None. init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. Default: None. """ def __init__(self, backbone, img_size=224, feature_size=None, in_channels=3, embed_dims=768, conv_cfg=None, init_cfg=None): super(HybridEmbed, self).__init__(init_cfg) assert isinstance(backbone, nn.Module) if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) assert len(img_size) == 2, \ f'The size of image should have length 1 or 2, ' \ f'but got {len(img_size)}' self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of # determining the exact dim of the output feature # map for all networks, the feature metadata has # reliable channel and stride info, but using # stride to calc feature dim requires info about padding of # each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone( torch.zeros(1, in_channels, img_size[0], img_size[1])) if isinstance(o, (list, tuple)): # last feature if backbone outputs list/tuple of features o = o[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) if hasattr(self.backbone, 'feature_info'): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features self.num_patches = feature_size[0] * feature_size[1] # Use conv layer to embed conv_cfg = conv_cfg or dict() _conv_cfg = dict( type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1) _conv_cfg.update(conv_cfg) self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims) def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): # last feature if backbone outputs list/tuple of features x = x[-1] x = self.projection(x).flatten(2).transpose(1, 2) return x class PatchMerging(BaseModule): """Merge patch feature map. Modified from mmcv, and this module supports specifying whether to use post-norm. This layer groups feature map by kernel_size, and applies norm and linear layers to the grouped feature map ((used in Swin Transformer)). Our implementation uses :class:`torch.nn.Unfold` to merge patches, which is about 25% faster than the original implementation. However, we need to modify pretrained models for compatibility. Args: in_channels (int): The num of input channels. To gets fully covered by filter and stride you specified. out_channels (int): The num of output channels. kernel_size (int | tuple, optional): the kernel size in the unfold layer. Defaults to 2. stride (int | tuple, optional): the stride of the sliding blocks in the unfold layer. Defaults to None, which means to be set as ``kernel_size``. padding (int | tuple | string ): The padding length of embedding conv. When it is a string, it means the mode of adaptive padding, support "same" and "corner" now. Defaults to "corner". dilation (int | tuple, optional): dilation parameter in the unfold layer. Defaults to 1. bias (bool, optional): Whether to add bias in linear layer or not. Defaults to False. norm_cfg (dict, optional): Config dict for normalization layer. Defaults to ``dict(type='LN')``. use_post_norm (bool): Whether to use post normalization here. Defaults to False. init_cfg (dict, optional): The extra config for initialization. Defaults to None. """ def __init__(self, in_channels, out_channels, kernel_size=2, stride=None, padding='corner', dilation=1, bias=False, norm_cfg=dict(type='LN'), use_post_norm=False, init_cfg=None): super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.out_channels = out_channels self.use_post_norm = use_post_norm if stride: stride = stride else: stride = kernel_size kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) dilation = to_2tuple(dilation) if isinstance(padding, str): self.adaptive_padding = AdaptivePadding( kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding) # disable the padding of unfold padding = 0 else: self.adaptive_padding = None padding = to_2tuple(padding) self.sampler = nn.Unfold( kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) sample_dim = kernel_size[0] * kernel_size[1] * in_channels self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) if norm_cfg is not None: # build pre or post norm layer based on different channels if self.use_post_norm: self.norm = build_norm_layer(norm_cfg, out_channels)[1] else: self.norm = build_norm_layer(norm_cfg, sample_dim)[1] else: self.norm = None def forward(self, x, input_size): """ Args: x (Tensor): Has shape (B, H*W, C_in). input_size (tuple[int]): The spatial shape of x, arrange as (H, W). Default: None. Returns: tuple: Contains merged results and its spatial shape. - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) - out_size (tuple[int]): Spatial shape of x, arrange as (Merged_H, Merged_W). """ B, L, C = x.shape assert isinstance(input_size, Sequence), f'Expect ' \ f'input_size is ' \ f'`Sequence` ' \ f'but get {input_size}' H, W = input_size assert L == H * W, 'input feature has wrong size' x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W if self.adaptive_padding: x = self.adaptive_padding(x) H, W = x.shape[-2:] # Use nn.Unfold to merge patch. About 25% faster than original method, # but need to modify pretrained model for compatibility # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) x = self.sampler(x) out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1) // self.sampler.stride[0] + 1 out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1) // self.sampler.stride[1] + 1 output_size = (out_h, out_w) x = x.transpose(1, 2) # B, H/2*W/2, 4*C if self.use_post_norm: # use post-norm here x = self.reduction(x) x = self.norm(x) if self.norm else x else: x = self.norm(x) if self.norm else x x = self.reduction(x) return x, output_size