# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.model import BaseModule, constant_init, normal_init

from mmdet.registry import MODELS
from ..layers import DyReLU

# Reference:
# https://github.com/microsoft/DynamicHead
# https://github.com/jshilong/SEPC


class DyDCNv2(nn.Module):
    """3x3 ModulatedDeformConv2d followed by an optional norm layer.

    Used as the spatial-aware convolution of DyHead. It cannot be expressed
    through `conv_cfg=dict(type='DCNv2')` because DyHead calculates the
    offset and mask from the middle-level feature and passes them in
    explicitly.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int | tuple[int], optional): Stride of the convolution.
            Default: 1.
        norm_cfg (dict, optional): Config dict for normalization layer.
            ``None`` disables normalization.
            Default: dict(type='GN', num_groups=16, requires_grad=True).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)):
        super().__init__()
        self.with_norm = norm_cfg is not None
        # The conv bias is enabled only when no norm layer follows.
        self.conv = ModulatedDeformConv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=not self.with_norm)
        if self.with_norm:
            # build_norm_layer returns (name, layer); only the layer is kept.
            name_and_layer = build_norm_layer(norm_cfg, out_channels)
            self.norm = name_and_layer[1]

    def forward(self, x, offset, mask):
        """Apply the deformable conv (and norm, if configured).

        Args:
            x (Tensor): Input feature map.
            offset (Tensor): Sampling offsets for the deformable conv.
            mask (Tensor): Modulation mask for the deformable conv.

        Returns:
            Tensor: Convolved (and optionally normalized) feature map.
        """
        out = self.conv(x.contiguous(), offset, mask)
        return self.norm(out) if self.with_norm else out


class DyHeadBlock(nn.Module):
    """Single DyHead block combining three types of attention.

    Spatial-aware attention is realised by deformable convolutions over the
    current level and its neighbours, scale-aware attention by a pooled
    gating branch, and task-aware attention by DyReLU.

    HSigmoid arguments in default act_cfg follow official code, not paper.
    https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        zero_init_offset (bool, optional): Whether to use zero init for
            `spatial_conv_offset`. Default: True.
        act_cfg (dict, optional): Config dict for the last activation layer
            of scale-aware attention.
            Default: dict(type='HSigmoid', bias=3.0, divisor=6.0).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 zero_init_offset=True,
                 act_cfg=dict(type='HSigmoid', bias=3.0, divisor=6.0)):
        super().__init__()
        self.zero_init_offset = zero_init_offset

        # One (offset_x, offset_y, mask) triple per position of the
        # 3x3 kernel: the first 2 * 9 channels are offsets, the rest masks.
        kernel_area = 3 * 3
        self.offset_and_mask_dim = 3 * kernel_area
        self.offset_dim = 2 * kernel_area

        # NOTE: the creation order of the sub-modules below is kept on
        # purpose so parameter ordering and random init stay unchanged.
        self.spatial_conv_high = DyDCNv2(in_channels, out_channels)
        self.spatial_conv_mid = DyDCNv2(in_channels, out_channels)
        self.spatial_conv_low = DyDCNv2(in_channels, out_channels, stride=2)
        self.spatial_conv_offset = nn.Conv2d(
            in_channels, self.offset_and_mask_dim, 3, padding=1)

        scale_attn_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, 1, 1),
            nn.ReLU(inplace=True),
            build_activation_layer(act_cfg),
        ]
        self.scale_attn_module = nn.Sequential(*scale_attn_layers)
        self.task_attn_module = DyReLU(out_channels)
        self._init_weights()

    def _init_weights(self):
        """Initialize every ``nn.Conv2d`` with N(0, 0.01).

        When ``zero_init_offset`` is set, `spatial_conv_offset` is then
        overwritten with constant zeros.
        """
        plain_convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in plain_convs:
            normal_init(conv, 0, 0.01)
        if not self.zero_init_offset:
            return
        constant_init(self.spatial_conv_offset, 0)

    def _scale_weighted(self, feat):
        """Return ``feat`` multiplied by its scale-aware attention weight."""
        return feat * self.scale_attn_module(feat)

    def forward(self, x):
        """Fuse every level with its neighbouring levels.

        Args:
            x (list[Tensor] | tuple[Tensor]): Multi-level feature maps.
                Presumably ordered from high to low resolution, since the
                previous level is reduced with a stride-2 conv and the next
                level is resized up to the current one.

        Returns:
            list[Tensor]: One fused feature map per input level.
        """
        top_level = len(x) - 1
        outs = []
        for level, cur_feat in enumerate(x):
            # calculate offset and mask of DCNv2 from middle-level feature;
            # the same offset and mask are reused for the neighbour levels
            offset_and_mask = self.spatial_conv_offset(cur_feat)
            offset = offset_and_mask[:, :self.offset_dim]
            mask = offset_and_mask[:, self.offset_dim:].sigmoid()

            mid_feat = self.spatial_conv_mid(cur_feat, offset, mask)
            fused = self._scale_weighted(mid_feat)
            num_fused = 1

            if level > 0:
                # previous level, brought to this level by a stride-2 conv
                low_feat = self.spatial_conv_low(x[level - 1], offset, mask)
                fused += self._scale_weighted(low_feat)
                num_fused += 1

            if level < top_level:
                # this upsample order is weird, but faster than natural order
                # https://github.com/microsoft/DynamicHead/issues/25
                high_conv = self.spatial_conv_high(x[level + 1], offset, mask)
                high_feat = F.interpolate(
                    high_conv,
                    size=cur_feat.shape[-2:],
                    mode='bilinear',
                    align_corners=True)
                fused += self._scale_weighted(high_feat)
                num_fused += 1

            # average over the contributing levels, then task-aware attention
            outs.append(self.task_attn_module(fused / num_fused))

        return outs


@MODELS.register_module()
class DyHead(BaseModule):
    """DyHead neck consisting of multiple DyHead Blocks.

    See `Dynamic Head: Unifying Object Detection Heads with Attentions
    <https://arxiv.org/abs/2106.08322>`_ for details.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_blocks (int, optional): Number of DyHead Blocks. Default: 6.
        zero_init_offset (bool, optional): Whether to use zero init for
            `spatial_conv_offset`. Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be left as None. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_blocks=6,
                 zero_init_offset=True,
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.zero_init_offset = zero_init_offset

        # Only the first block consumes `in_channels`; every later block
        # takes the `out_channels` produced by its predecessor.
        blocks = [
            DyHeadBlock(
                self.in_channels if idx == 0 else self.out_channels,
                self.out_channels,
                zero_init_offset=zero_init_offset)
            for idx in range(num_blocks)
        ]
        self.dyhead_blocks = nn.Sequential(*blocks)

    def forward(self, inputs):
        """Run the multi-level features through all DyHead blocks.

        Args:
            inputs (tuple[Tensor] | list[Tensor]): Multi-level feature maps.

        Returns:
            tuple[Tensor]: Refined feature maps, one per input level.
        """
        assert isinstance(inputs, (tuple, list))
        return tuple(self.dyhead_blocks(inputs))