ai-photo-gallery / mmdet /models /necks /dilated_encoder.py
KyanChen's picture
init
f549064
raw
history blame
3.97 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, is_norm
from mmengine.model import caffe2_xavier_init, constant_init, normal_init
from torch.nn import BatchNorm2d
from mmdet.registry import MODELS
class Bottleneck(nn.Module):
"""Bottleneck block for DilatedEncoder used in `YOLOF.
<https://arxiv.org/abs/2103.09460>`.
The Bottleneck contains three ConvLayers and one residual connection.
Args:
in_channels (int): The number of input channels.
mid_channels (int): The number of middle output channels.
dilation (int): Dilation rate.
norm_cfg (dict): Dictionary to construct and config norm layer.
"""
def __init__(self,
in_channels,
mid_channels,
dilation,
norm_cfg=dict(type='BN', requires_grad=True)):
super(Bottleneck, self).__init__()
self.conv1 = ConvModule(
in_channels, mid_channels, 1, norm_cfg=norm_cfg)
self.conv2 = ConvModule(
mid_channels,
mid_channels,
3,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg)
self.conv3 = ConvModule(
mid_channels, in_channels, 1, norm_cfg=norm_cfg)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
out = out + identity
return out
@MODELS.register_module()
class DilatedEncoder(nn.Module):
"""Dilated Encoder for YOLOF <https://arxiv.org/abs/2103.09460>`.
This module contains two types of components:
- the original FPN lateral convolution layer and fpn convolution layer,
which are 1x1 conv + 3x3 conv
- the dilated residual block
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
block_mid_channels (int): The number of middle block output channels
num_residual_blocks (int): The number of residual blocks.
block_dilations (list): The list of residual blocks dilation.
"""
def __init__(self, in_channels, out_channels, block_mid_channels,
num_residual_blocks, block_dilations):
super(DilatedEncoder, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.block_mid_channels = block_mid_channels
self.num_residual_blocks = num_residual_blocks
self.block_dilations = block_dilations
self._init_layers()
def _init_layers(self):
self.lateral_conv = nn.Conv2d(
self.in_channels, self.out_channels, kernel_size=1)
self.lateral_norm = BatchNorm2d(self.out_channels)
self.fpn_conv = nn.Conv2d(
self.out_channels, self.out_channels, kernel_size=3, padding=1)
self.fpn_norm = BatchNorm2d(self.out_channels)
encoder_blocks = []
for i in range(self.num_residual_blocks):
dilation = self.block_dilations[i]
encoder_blocks.append(
Bottleneck(
self.out_channels,
self.block_mid_channels,
dilation=dilation))
self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)
def init_weights(self):
caffe2_xavier_init(self.lateral_conv)
caffe2_xavier_init(self.fpn_conv)
for m in [self.lateral_norm, self.fpn_norm]:
constant_init(m, 1)
for m in self.dilated_encoder_blocks.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, mean=0, std=0.01)
if is_norm(m):
constant_init(m, 1)
def forward(self, feature):
out = self.lateral_norm(self.lateral_conv(feature[-1]))
out = self.fpn_norm(self.fpn_conv(out))
return self.dilated_encoder_blocks(out),