# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.utils import digit_version


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): The config dict for initializing the
            module. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
        # Depth-wise 3x3 conv (groups == embed_dims) that produces the
        # position encoding. NOTE(review): with groups=embed_dims this
        # presumes in_channels == embed_dims, otherwise Conv2d rejects
        # the grouping — confirm against callers.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        """Add conditional position encodings to the token sequence.

        Args:
            x (torch.Tensor): Input tokens of shape (B, N, C) with
                ``N == H * W``.
            hw_shape (tuple[int]): The (H, W) spatial shape the tokens
                were flattened from.

        Returns:
            torch.Tensor: Position-encoded tokens of shape (B, N', C),
            where N' == N when ``stride == 1``.
        """
        B, N, C = x.shape
        H, W = hw_shape
        feat_token = x
        # convert (B, N, C) to (B, C, H, W)
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous()
        if self.stride == 1:
            # Residual connection: at stride 1 the conv output matches the
            # input resolution, so the encoding is added to the features.
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        # back to token layout (B, N', C)
        x = x.flatten(2).transpose(1, 2)
        return x


class PositionEncodingFourier(BaseModule):
    """The Position Encoding Fourier (PEF) module.

    The PEF is adopted from `EdgeNeXt
    <https://arxiv.org/abs/2206.10589>`_.

    Args:
        in_channels (int): Number of input channels. Default: 32
        embed_dims (int): The feature dimension. Default: 768.
        temperature (int): Temperature. Default: 10000.
        dtype (torch.dtype): The data type. Default: torch.float32.
        init_cfg (dict): The config dict for initializing the module.
            Default: None.
    """

    def __init__(self,
                 in_channels=32,
                 embed_dims=768,
                 temperature=10000,
                 dtype=torch.float32,
                 init_cfg=None):
        super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg)
        # 1x1 conv mixing the 2 * in_channels sin/cos features (x and y
        # each contribute in_channels) up to embed_dims.
        self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1)
        self.scale = 2 * math.pi
        self.in_channels = in_channels
        self.embed_dims = embed_dims
        self.dtype = dtype

        # torch.div(..., rounding_mode='floor') replaced torch.floor_divide
        # in PyTorch 1.8; pick whichever this runtime supports.
        if digit_version(torch.__version__) < digit_version('1.8.0'):
            floor_div = torch.floor_divide
        else:
            floor_div = partial(torch.div, rounding_mode='floor')
        dim_t = torch.arange(in_channels, dtype=self.dtype)
        # Fourier frequency ladder: temperature^(2*floor(i/2)/in_channels).
        self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels)

    def forward(self, bhw_shape):
        """Compute the Fourier position encoding for a feature map.

        Args:
            bhw_shape (tuple[int]): The (B, H, W) shape to encode.

        Returns:
            torch.Tensor: Position encoding of shape (B, embed_dims, H, W).
        """
        B, H, W = bhw_shape
        # Build the all-False mask directly with the target dtype/device
        # instead of float-allocating on CPU and then casting/moving.
        mask = torch.zeros(
            B, H, W, dtype=torch.bool, device=self.proj.weight.device)
        not_mask = ~mask
        eps = 1e-6
        # Cumulative sums give each position its (row, col) index + 1.
        y_embed = not_mask.cumsum(1, dtype=self.dtype)
        x_embed = not_mask.cumsum(2, dtype=self.dtype)
        # Normalize coordinates to [0, 2*pi] (eps avoids division by zero).
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.dim_t.to(mask.device)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels with cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        # (B, H, W, 2*in_channels) -> (B, 2*in_channels, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.proj(pos)

        return pos