ai-photo-gallery / mmcls /models /utils /position_encoding.py
KyanChen's picture
init
f549064
raw
history blame contribute delete
No virus
3.78 kB
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.utils import digit_version
class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    The token sequence is folded back into a 2-D feature map, passed
    through a grouped 3x3 convolution (``groups=embed_dims``) and flattened
    into a token sequence again. When ``stride`` is 1 the convolution
    output is added to its input as a residual.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): Forwarded unchanged to the base class
            constructor. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.stride = stride
        # 3x3 grouped convolution; padding=1 keeps the spatial size
        # unchanged when stride is 1.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)

    def forward(self, x, hw_shape):
        """Apply the convolutional position encoding to a token sequence.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C). ``N`` must equal
                ``H * W`` for the reshape below to succeed.
            hw_shape (tuple[int, int]): Spatial size ``(H, W)`` of the
                feature map the tokens were flattened from.

        Returns:
            torch.Tensor: Tokens of shape (B, N', embed_dims), where ``N'``
            is the number of spatial positions after the convolution.
        """
        batch_size, _, channels = x.shape
        height, width = hw_shape

        # (B, N, C) -> (B, C, H, W)
        feat_map = x.transpose(1, 2)
        feat_map = feat_map.view(batch_size, channels, height, width)
        feat_map = feat_map.contiguous()

        encoded = self.proj(feat_map)
        if self.stride == 1:
            # Resolution is unchanged, so the input can be added back.
            encoded = encoded + feat_map

        # (B, C', H', W') -> (B, H' * W', C')
        return encoded.flatten(2).transpose(1, 2)
class PositionEncodingFourier(BaseModule):
    """The Position Encoding Fourier (PEF) module.

    The PEF is adopted from `EdgeNeXt <https://arxiv.org/abs/2206.10589>`_.

    Normalized row/column indices are expanded into ``in_channels``
    sine/cosine features per axis, concatenated (``2 * in_channels``
    channels) and projected to ``embed_dims`` channels by a 1x1 convolution.

    Args:
        in_channels (int): Number of sinusoidal features generated for each
            spatial axis. The sin/cos halves are stacked pairwise, so this
            needs to be even. Default: 32
        embed_dims (int): The feature dimension.
            Default: 768.
        temperature (int): Temperature.
            Default: 10000.
        dtype (torch.dtype): The data type.
            Default: torch.float32.
        init_cfg (dict): The config dict for initializing the module.
            Default: None.
    """

    def __init__(self,
                 in_channels=32,
                 embed_dims=768,
                 temperature=10000,
                 dtype=torch.float32,
                 init_cfg=None):
        super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg)
        # The y and x encodings are concatenated, hence ``in_channels * 2``.
        self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1)
        self.scale = 2 * math.pi
        self.in_channels = in_channels
        self.embed_dims = embed_dims
        self.dtype = dtype

        # ``torch.div(..., rounding_mode=...)`` is only used from
        # torch 1.8.0 on; older versions fall back to ``torch.floor_divide``.
        if digit_version(torch.__version__) < digit_version('1.8.0'):
            floor_div = torch.floor_divide
        else:
            floor_div = partial(torch.div, rounding_mode='floor')
        dim_t = torch.arange(in_channels, dtype=self.dtype)
        # Consecutive index pairs (2k, 2k + 1) share the same divisor.
        self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels)

    def forward(self, bhw_shape):
        """Build the projected Fourier position encoding.

        Args:
            bhw_shape (tuple[int, int, int]): ``(B, H, W)`` of the target
                feature map.

        Returns:
            torch.Tensor: Position encoding of shape
            (B, embed_dims, H, W), on the same device as ``self.proj``.
        """
        B, H, W = bhw_shape
        device = self.proj.weight.device
        # All positions are valid. Allocate the boolean tensor directly on
        # the target device instead of building a CPU float tensor, casting
        # it to bool, copying it over and inverting it.
        not_mask = torch.ones(B, H, W, dtype=torch.bool, device=device)
        eps = 1e-6

        # Cumulative sums give 1-based row / column indices.
        y_embed = not_mask.cumsum(1, dtype=self.dtype)
        x_embed = not_mask.cumsum(2, dtype=self.dtype)
        # Normalize by the last index and scale to (0, 2 * pi].
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # ``self.dim_t`` is a plain attribute (not a parameter or buffer),
        # so it has to be moved to the working device explicitly.
        dim_t = self.dim_t.to(device)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)

        # (B, H, W, 2 * in_channels) -> (B, 2 * in_channels, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.proj(pos)
        return pos