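# NOTE(review): the three lines below look like web-page residue (uploader
# caption, commit message, commit hash) rather than source code; they are kept
# as comments so that the module remains parseable.
# KyanChen's picture
# init
# f549064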
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer
class ConvBlock(BaseModule):
    """Bottleneck-style convolution block used in Conformer.

    The block chains three convolution stages (1x1 -> 3x3 -> 1x1), each
    followed by a normalization layer, and offers three extras on top of a
    plain bottleneck:

    1. It can return the activated output of the second convolution stage
       together with the final output.
    2. It can add an extra feature map to the input of the second
       convolution.
    3. It can apply an extra conv-norm pair on the identity connection.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        stride (int): The stride of the second convolution (and of the
            optional residual convolution). Defaults to 1.
        groups (int): The groups of the second convolution.
            Defaults to 1.
        drop_path_rate (float): The rate of the DropPath layer. A value
            that is not greater than 0 installs ``nn.Identity`` instead.
            Defaults to 0.
        with_residual_conv (bool): Whether to add an extra conv-norm pair
            to the identity connection. Defaults to False.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 groups=1,
                 drop_path_rate=0.,
                 with_residual_conv=False,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # The hidden width is a quarter of the output width.
        mid_channels = out_channels // 4

        def _norm(num_channels):
            # ``build_norm_layer`` result is indexed with [1] to get the layer.
            return build_norm_layer(norm_cfg, num_channels)[1]

        # Stage 1: pointwise reduction.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.bn1 = _norm(mid_channels)
        self.act1 = build_activation_layer(act_cfg)

        # Stage 2: (optionally strided / grouped) 3x3 convolution.
        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            3,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False)
        self.bn2 = _norm(mid_channels)
        self.act2 = build_activation_layer(act_cfg)

        # Stage 3: pointwise expansion.
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        self.bn3 = _norm(out_channels)
        self.act3 = build_activation_layer(act_cfg)

        # Optional projection of the identity branch so that it matches the
        # main branch in channels and stride.
        if with_residual_conv:
            self.residual_conv = nn.Conv2d(
                in_channels, out_channels, 1, stride=stride, bias=False)
            self.residual_bn = _norm(out_channels)

        self.with_residual_conv = with_residual_conv

        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def zero_init_last_bn(self):
        """Set the weight of the last normalization layer to zero."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, fusion_features=None, out_conv2=True):
        """Run the block.

        Args:
            x (Tensor): The input feature map.
            fusion_features (Tensor, optional): A feature map added to the
                input of the second convolution. Defaults to None.
            out_conv2 (bool): Whether to also return the activated output
                of the second convolution stage. Defaults to True.

        Returns:
            Tensor | tuple[Tensor, Tensor]: The block output, or
            ``(output, conv2_output)`` if ``out_conv2`` is True.
        """
        shortcut = x

        out = self.act1(self.bn1(self.conv1(x)))

        if fusion_features is not None:
            out = out + fusion_features
        conv2_out = self.act2(self.bn2(self.conv2(out)))

        out = self.bn3(self.conv3(conv2_out))
        if self.drop_path is not None:
            out = self.drop_path(out)

        if self.with_residual_conv:
            shortcut = self.residual_bn(self.residual_conv(shortcut))

        out += shortcut
        out = self.act3(out)

        return (out, conv2_out) if out_conv2 else out
class FCUDown(BaseModule):
    """CNN feature maps -> Transformer patch embeddings.

    The feature map is projected by a 1x1 convolution, average-pooled with
    ``down_stride``, flattened into a token sequence and passed through a
    normalization layer and an activation.

    Args:
        in_channels (int): The number of channels of the input feature map.
        out_channels (int): The number of channels of the output tokens.
        down_stride (int): Kernel size and stride of the average pooling.
        with_cls_token (bool): Whether to prepend the first token of the
            transformer input to the produced tokens. Defaults to True.
        norm_cfg (dict): The config of the normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): The config of the activation function.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 down_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.down_stride = down_stride
        self.with_cls_token = with_cls_token

        self.conv_project = nn.Conv2d(in_channels, out_channels, 1)
        self.sample_pooling = nn.AvgPool2d(down_stride, stride=down_stride)

        self.ln = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, x_t):
        """Turn a feature map into tokens.

        Args:
            x (Tensor): The CNN feature map, ``[N, C, H, W]``.
            x_t (Tensor): The transformer tokens; only the first token
                (index 0 of dim 1) is read, and only when
                ``with_cls_token`` is True.

        Returns:
            Tensor: The token sequence, with the first token of ``x_t``
            prepended when ``with_cls_token`` is True.
        """
        pooled = self.sample_pooling(self.conv_project(x))
        # [N, C, H', W'] -> [N, C, H'*W'] -> [N, H'*W', C]
        tokens = pooled.flatten(2).transpose(1, 2)
        tokens = self.act(self.ln(tokens))

        if not self.with_cls_token:
            return tokens

        cls_token = x_t[:, 0].unsqueeze(1)
        return torch.cat([cls_token, tokens], dim=1)
class FCUUp(BaseModule):
    """Transformer patch embeddings -> CNN feature maps.

    The tokens (without the leading class token, if any) are reshaped into a
    2-D map, projected by a 1x1 convolution followed by a normalization
    layer and an activation, and finally resized by ``up_stride``.

    Args:
        in_channels (int): The number of channels of the input tokens.
        out_channels (int): The number of channels of the output feature map.
        up_stride (int): The factor by which the spatial size is enlarged.
        with_cls_token (bool): Whether the input carries a leading class
            token that must be dropped. Defaults to True.
        norm_cfg (dict): The config of the normalization layer.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of the activation function.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 up_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.up_stride = up_stride
        self.with_cls_token = with_cls_token

        self.conv_project = nn.Conv2d(in_channels, out_channels, 1)
        self.bn = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, H, W):
        """Turn tokens into a feature map.

        Args:
            x (Tensor): The token sequence, ``[N, L, C]``.
            H (int): The height of the token grid.
            W (int): The width of the token grid.

        Returns:
            Tensor: A feature map of spatial size
            ``(H * up_stride, W * up_stride)``.
        """
        B, _, C = x.shape

        # e.g. [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196]
        #      -> [N, 384, 14, 14]
        patch_tokens = x[:, 1:] if self.with_cls_token else x
        feature_map = patch_tokens.transpose(1, 2).reshape(B, C, H, W)

        feature_map = self.conv_project(feature_map)
        feature_map = self.act(self.bn(feature_map))

        out_size = (H * self.up_stride, W * self.up_stride)
        return F.interpolate(feature_map, size=out_size)
class ConvTransBlock(BaseModule):
    """Basic module for Conformer.

    This module fuses a CNN block with a transformer encoder block: the
    hidden features of the CNN block are converted to tokens and added to the
    transformer input, and the transformer output is converted back to a
    feature map and fused into a second CNN block.

    Args:
        in_channels (int): The number of input channels in conv blocks.
        out_channels (int): The number of output channels in conv blocks.
        embed_dims (int): The embedding dimension in transformer blocks.
        conv_stride (int): The stride of conv2d layers. Defaults to 1.
        groups (int): The groups of conv blocks. Defaults to 1.
        with_residual_conv (bool): Whether to add a conv-bn layer to the
            identity connect in the conv block. Defaults to False.
        down_stride (int): The stride of the downsample pooling layer.
            Defaults to 4.
        num_heads (int): The number of heads in transformer attention layers.
            Defaults to 12.
        mlp_ratio (float): The expansion ratio in transformer FFN module.
            Defaults to 4.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_rate (float): The dropout rate of the output projection and
            FFN in the transformer block. Defaults to 0.
        attn_drop_rate (float): The dropout rate after the attention
            calculation in the transformer block. Defaults to 0.
        drop_path_rate (float): The drop path rate passed to the fusion conv
            block and the transformer block. Defaults to 0.
        last_fusion (bool): Whether this block is the last stage. If so,
            downsample the fusion feature map. Defaults to False.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 embed_dims,
                 conv_stride=1,
                 groups=1,
                 with_residual_conv=False,
                 down_stride=4,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 with_cls_token=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 last_fusion=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Channel count of the conv blocks' hidden (conv2) features, which
        # are the features exchanged with the transformer branch.
        hidden_channels = out_channels // 4

        self.cnn_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            with_residual_conv=with_residual_conv,
            stride=conv_stride,
            groups=groups)

        # The fusion block of the last stage additionally downsamples by 2
        # and therefore needs a projection on its identity branch.
        fusion_kwargs = dict(
            in_channels=out_channels,
            out_channels=out_channels,
            groups=groups,
            drop_path_rate=drop_path_rate)
        if last_fusion:
            fusion_kwargs.update(stride=2, with_residual_conv=True)
        self.fusion_block = ConvBlock(**fusion_kwargs)

        self.squeeze_block = FCUDown(
            in_channels=hidden_channels,
            out_channels=embed_dims,
            down_stride=down_stride,
            with_cls_token=with_cls_token)

        self.expand_block = FCUUp(
            in_channels=embed_dims,
            out_channels=hidden_channels,
            up_stride=down_stride,
            with_cls_token=with_cls_token)

        self.trans_block = TransformerEncoderLayer(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(embed_dims * mlp_ratio),
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        self.down_stride = down_stride
        self.embed_dim = embed_dims
        self.last_fusion = last_fusion

    def forward(self, cnn_input, trans_input):
        """Run one fused CNN / transformer step.

        Args:
            cnn_input (Tensor): The feature map of the CNN branch.
            trans_input (Tensor): The tokens of the transformer branch.

        Returns:
            tuple[Tensor, Tensor]: The updated CNN feature map and the
            updated transformer tokens.
        """
        cnn_out, conv2_feat = self.cnn_block(cnn_input, out_conv2=True)
        height, width = conv2_feat.shape[2:]

        # CNN -> transformer: tokenise the conv2 features (prepending the
        # class token if enabled) and add them to the transformer input.
        conv2_tokens = self.squeeze_block(conv2_feat, trans_input)
        trans_output = self.trans_block(conv2_tokens + trans_input)

        # Transformer -> CNN: rebuild a feature map from the tokens and fuse
        # it into the second conv block.
        grid_h = height // self.down_stride
        grid_w = width // self.down_stride
        trans_features = self.expand_block(trans_output, grid_h, grid_w)
        cnn_out = self.fusion_block(
            cnn_out, fusion_features=trans_features, out_conv2=False)

        return cnn_out, trans_output
@MODELS.register_module()
class Conformer(BaseBackbone):
    """Conformer backbone.

    A PyTorch implementation of : `Conformer: Local Features Coupling Global
    Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_

    Args:
        arch (str | dict): Conformer architecture. Either one of the keys of
            ``arch_zoo`` (case-insensitive) or a dict with exactly the keys
            ``embed_dims``, ``depths``, ``num_heads`` and ``channel_ratio``.
            Defaults to 'tiny'.
        patch_size (int): The patch size. Must be divisible by 16.
            Defaults to 16.
        base_channels (int): The base number of channels in CNN network.
            Defaults to 64.
        mlp_ratio (float): The expansion ratio of FFN network in transformer
            block. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv in the transformer blocks if
            True. Defaults to True.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_eval (bool): Stored as ``self.norm_eval``; the code in this
            class body does not read it. Defaults to True.
        frozen_stages (int): Stored as ``self.frozen_stages``; the code in
            this class body does not read it. Defaults to 0.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage. The argument itself is never
            modified. Defaults to -1, means the last stage.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 384,
                         'channel_ratio': 1,
                         'num_heads': 6,
                         'depths': 12
                         }),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 384,
                         'channel_ratio': 4,
                         'num_heads': 6,
                         'depths': 12
                         }),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 576,
                         'channel_ratio': 6,
                         'num_heads': 9,
                         'depths': 12
                         }),
    }  # yapf: disable

    _version = 1

    def __init__(self,
                 arch='tiny',
                 patch_size=16,
                 base_channels=64,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 with_cls_token=True,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=0,
                 out_indices=-1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads', 'channel_ratio'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.num_features = self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.channel_ratio = self.arch_settings['channel_ratio']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: the caller may pass an immutable sequence
        # (e.g. a tuple, which does not support item assignment) and a
        # caller-owned list must not be modified behind the caller's back.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # Negative indices count from the last stage (-1 -> depths).
                out_indices[i] = self.depths + index + 1
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # stochastic depth decay rule
        self.trans_dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
        ]

        # Stem stage: get the feature maps by conv block
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3,
            bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
            'be divisible by 16.'
        # The stem already downsamples by 4, so the remaining stride needed
        # to reach the patch size is patch_size // 4.
        trans_down_stride = patch_size // 4

        # To solve the issue #680
        # Auto pad the feature map to be divisible by trans_down_stride
        self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)

        # 1 stage
        stage1_channels = int(base_channels * self.channel_ratio)
        self.conv_1 = ConvBlock(
            in_channels=64,
            out_channels=stage1_channels,
            with_residual_conv=True,
            stride=1)
        self.trans_patch_conv = nn.Conv2d(
            64,
            self.embed_dims,
            kernel_size=trans_down_stride,
            stride=trans_down_stride,
            padding=0)

        self.trans_1 = TransformerEncoderLayer(
            embed_dims=self.embed_dims,
            num_heads=self.num_heads,
            feedforward_channels=int(self.embed_dims * mlp_ratio),
            drop_path_rate=self.trans_dpr[0],
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        # 2~4 stage (stage numbers in these comments are for depths=12)
        init_stage = 2
        fin_stage = self.depths // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=stage1_channels,
                    out_channels=stage1_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=1,
                    with_residual_conv=False,
                    down_stride=trans_down_stride,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage2_channels = int(base_channels * self.channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + self.depths // 3  # 9
        for i in range(init_stage, fin_stage):
            # Only the first block of the group changes resolution and
            # width, so only it needs a projected identity branch.
            if i == init_stage:
                conv_stride = 2
                in_channels = stage1_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage2_channels
                with_residual_conv = False

            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage2_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 2,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + self.depths // 3  # 13
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage2_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage3_channels
                with_residual_conv = False

            last_fusion = (i == self.depths)

            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage3_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 4,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token,
                    last_fusion=last_fusion))
        self.fin_stage = fin_stage

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.trans_norm = nn.LayerNorm(self.embed_dims)

        if self.with_cls_token:
            trunc_normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        """Initialize a single sub-module according to its type.

        Args:
            m (nn.Module): The sub-module to initialize; intended to be
                used with ``nn.Module.apply``.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

        # Independent of the type dispatch above: modules that offer the
        # hook (e.g. ConvBlock) zero their last normalization weight.
        if hasattr(m, 'zero_init_last_bn'):
            m.zero_init_last_bn()

    def init_weights(self):
        """Initialize the weights of the whole backbone."""
        super(Conformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        self.apply(self._init_weights)

    def forward(self, x):
        """Extract features from the selected stages.

        Args:
            x (Tensor): The input images.

        Returns:
            tuple[list[Tensor]]: One ``[conv_feature, trans_feature]`` pair
            for every stage listed in ``out_indices``. ``conv_feature`` is
            the globally pooled and flattened CNN feature map, and
            ``trans_feature`` is the normalized class token (or the mean of
            the normalized tokens if no class token is used).
        """
        output = []
        B = x.shape[0]
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
        x_base = self.auto_pad(x_base)

        # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
        x = self.conv_1(x_base, out_conv2=False)
        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        if self.with_cls_token:
            x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        # NOTE(review): only stages >= 2 are checked against out_indices, so
        # a stage index of 1 never yields an output.
        for i in range(2, self.fin_stage):
            stage = getattr(self, f'conv_trans_{i}')
            x, x_t = stage(x, x_t)
            if i in self.out_indices:
                conv_feature = self.pooling(x).flatten(1)
                normed_tokens = self.trans_norm(x_t)
                if self.with_cls_token:
                    trans_feature = normed_tokens[:, 0]
                else:
                    # if no class token, use the mean patch token
                    # as the transformer feature.
                    trans_feature = normed_tokens.mean(dim=1)
                output.append([conv_feature, trans_feature])

        return tuple(output)