KyanChen's picture
init
f549064
raw
history blame
3.52 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.utils.checkpoint import checkpoint
from mmdet.registry import MODELS
@MODELS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids).

    paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    The forward pass upsamples branch ``i`` by a factor of ``2**i``,
    concatenates all branches along dim 1, reduces the result with a 1x1
    conv, pools it into ``num_outs`` levels (level ``i`` uses kernel and
    stride ``2**i``) and applies one 3x3 conv per level.

    Args:
        in_channels (list): number of channels for each branch.
        out_channels (int): output channels of feature pyramids.
        num_outs (int): number of output stages.
        pooling_type (str): pooling for generating feature pyramids
            from {MAX, AVG}. Any value other than ``'MAX'`` selects average
            pooling; values outside the documented set additionally emit a
            ``UserWarning``.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
            It is stored on the instance but is not passed to the
            ``ConvModule`` layers built in ``__init__``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        stride (int): stride of 3x3 convolutional layers
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        # NOTE(review): norm_cfg is kept for reference only; neither
        # ConvModule below receives it. Forwarding it would change the layer
        # structure, so confirm against existing checkpoints before doing so.
        self.norm_cfg = norm_cfg

        # 1x1 conv that fuses the concatenated branches into out_channels.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        # One 3x3 conv per output level.
        self.fpn_convs = nn.ModuleList(
            ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                stride=stride,
                conv_cfg=self.conv_cfg,
                act_cfg=None) for _ in range(self.num_outs))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            if pooling_type != 'AVG':
                # Keep the historical fallback, but make a typo such as
                # 'max' visible instead of silently switching the pooling.
                import warnings
                warnings.warn(
                    f'Unrecognized pooling_type {pooling_type!r}; expected '
                    "'MAX' or 'AVG'. Falling back to average pooling.")
            self.pooling = F.avg_pool2d

    def _maybe_checkpoint(self, module, x):
        """Run ``module(x)``, via checkpointing when enabled and useful."""
        if x.requires_grad and self.with_cp:
            return checkpoint(module, x)
        return module(x)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): one feature map per branch; its
                length must equal ``len(in_channels)``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps, one per pyramid level.
        """
        assert len(inputs) == self.num_ins

        # Branch 0 is used as-is; branch i is upsampled by 2**i.
        branches = [inputs[0]]
        branches.extend(
            F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear')
            for i in range(1, self.num_ins))

        fused = self._maybe_checkpoint(self.reduction_conv,
                                       torch.cat(branches, dim=1))

        # Level 0 is the fused map; level i is pooled with kernel/stride 2**i.
        pyramid = [fused]
        pyramid.extend(
            self.pooling(fused, kernel_size=2**i, stride=2**i)
            for i in range(1, self.num_outs))

        return tuple(
            self._maybe_checkpoint(conv, level)
            for conv, level in zip(self.fpn_convs, pyramid))