# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.models.layers import ResLayer, SimplifiedBasicBlock
from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class GlobalContextHead(BaseModule):
"""Global context head used in `SCNet <https://arxiv.org/abs/2012.10150>`_.
Args:
        num_convs (int, optional): number of convolutional layers in the
            global context head. Defaults to 4.
in_channels (int, optional): number of input channels. Defaults to 256.
conv_out_channels (int, optional): number of output channels before
classification layer. Defaults to 256.
num_classes (int, optional): number of classes. Defaults to 80.
loss_weight (float, optional): global context loss weight.
Defaults to 1.
        conv_cfg (dict, optional): config dict used to build the conv layers.
            Defaults to None.
        norm_cfg (dict, optional): config dict used to build the norm layers.
            Defaults to None.
conv_to_res (bool, optional): if True, 2 convs will be grouped into
1 `SimplifiedBasicBlock` using a skip connection.
Defaults to False.
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
list[:obj:`ConfigDict`]): Initialization config dict. Defaults to
dict(type='Normal', std=0.01, override=dict(name='fc')).
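
    Example:
        A minimal usage sketch; the batch size and feature-map sizes below
        are illustrative, not part of the documented API::

            >>> import torch
            >>> head = GlobalContextHead(num_convs=4, num_classes=80)
            >>> feats = tuple(
            ...     torch.rand(2, 256, s, s) for s in (64, 32, 16, 8))
            >>> mc_pred, ctx = head(feats)
            >>> mc_pred.shape   # multi-class logits
            torch.Size([2, 80])
            >>> ctx.shape       # pooled global context feature
            torch.Size([2, 256, 1, 1])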
"""

    def __init__(
self,
num_convs: int = 4,
in_channels: int = 256,
conv_out_channels: int = 256,
num_classes: int = 80,
loss_weight: float = 1.0,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None,
conv_to_res: bool = False,
init_cfg: MultiConfig = dict(
type='Normal', std=0.01, override=dict(name='fc'))
) -> None:
super().__init__(init_cfg=init_cfg)
self.num_convs = num_convs
self.in_channels = in_channels
self.conv_out_channels = conv_out_channels
self.num_classes = num_classes
self.loss_weight = loss_weight
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.conv_to_res = conv_to_res
self.fp16_enabled = False

        if self.conv_to_res:
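            # group every two 3x3 convs into one SimplifiedBasicBlock,
            # which adds a skip connection around each pair of convs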
num_res_blocks = num_convs // 2
self.convs = ResLayer(
SimplifiedBasicBlock,
in_channels,
self.conv_out_channels,
num_res_blocks,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
self.num_convs = num_res_blocks
else:
self.convs = nn.ModuleList()
for i in range(self.num_convs):
in_channels = self.in_channels if i == 0 else conv_out_channels
self.convs.append(
ConvModule(
in_channels,
conv_out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
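
        # global average pooling + linear classifier for multi-label
        # prediction over all classes present in the image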
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(conv_out_channels, num_classes)
self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            feats (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
Tuple[Tensor]:
- mc_pred (Tensor): Multi-class prediction.
- x (Tensor): Global context feature.
"""
x = feats[-1]
for i in range(self.num_convs):
x = self.convs[i](x)
x = self.pool(x)
# multi-class prediction
mc_pred = x.reshape(x.size(0), -1)
mc_pred = self.fc(mc_pred)
return mc_pred, x

    def loss(self, pred: Tensor, labels: List[Tensor]) -> Tensor:
        """Loss function.

        Args:
            pred (Tensor): Logits.
            labels (list[Tensor]): Ground-truth labels of each image.

        Returns:
            Tensor: Loss.
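
        Example:
            A minimal sketch of how the multi-hot targets are built; the
            class indices and shapes are illustrative::

                >>> import torch
                >>> head = GlobalContextHead(num_classes=4)
                >>> pred = torch.randn(2, 4)  # logits
                >>> gt = [torch.tensor([0, 2]), torch.tensor([1, 1])]
                >>> loss = head.loss(pred, gt)  # duplicates collapse via unique()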
"""
labels = [lbl.unique() for lbl in labels]
targets = pred.new_zeros(pred.size())
for i, label in enumerate(labels):
targets[i, label] = 1.0
loss = self.loss_weight * self.criterion(pred, targets)
return loss