# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class FusedSemanticHead(BaseModule):
r"""Multi-level fused semantic segmentation head.

    .. code-block:: none

        in_1 -> 1x1 conv ---
                            |
        in_2 -> 1x1 conv -- |
                           ||
        in_3 -> 1x1 conv - ||
                          |||                  /-> 1x1 conv (mask prediction)
        in_4 -> 1x1 conv -----> 3x3 convs (*4)
                            |                  \-> 1x1 conv (feature)
        in_5 -> 1x1 conv ---
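
    Example:
        A minimal usage sketch (the five input levels, their shapes and the
        fusion level are illustrative assumptions, not requirements of this
        head):

        >>> import torch
        >>> # five FPN-style levels with 256 channels each
        >>> feats = tuple(
        ...     torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(5))
        >>> head = FusedSemanticHead(num_ins=5, fusion_level=1)
        >>> mask_preds, fused = head(feats)
        >>> mask_preds.shape, fused.shape
        (torch.Size([1, 183, 32, 32]), torch.Size([1, 256, 32, 32]))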
""" # noqa: W605
def __init__(
self,
num_ins: int,
fusion_level: int,
        seg_scale_factor: float = 1 / 8,
num_convs: int = 4,
in_channels: int = 256,
conv_out_channels: int = 256,
num_classes: int = 183,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None,
        ignore_label: Optional[int] = None,
        loss_weight: Optional[float] = None,
loss_seg: ConfigDict = dict(
type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2),
init_cfg: MultiConfig = dict(
type='Kaiming', override=dict(name='conv_logits'))
) -> None:
super().__init__(init_cfg=init_cfg)
self.num_ins = num_ins
self.fusion_level = fusion_level
self.seg_scale_factor = seg_scale_factor
self.num_convs = num_convs
self.in_channels = in_channels
self.conv_out_channels = conv_out_channels
self.num_classes = num_classes
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.fp16_enabled = False
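        # One 1x1 lateral conv per input level; their outputs are summed at
        # the fusion level's resolution in forward().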
self.lateral_convs = nn.ModuleList()
for i in range(self.num_ins):
self.lateral_convs.append(
ConvModule(
self.in_channels,
self.in_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=False))
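        # A stack of 3x3 convs refines the fused feature map.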
self.convs = nn.ModuleList()
for i in range(self.num_convs):
in_channels = self.in_channels if i == 0 else conv_out_channels
self.convs.append(
ConvModule(
in_channels,
conv_out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
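        # Two parallel 1x1 branches on the refined feature: one predicts the
        # per-class semantic logits, the other the fused semantic feature.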
self.conv_embedding = ConvModule(
conv_out_channels,
conv_out_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
        if ignore_label is not None:
            loss_seg['ignore_index'] = ignore_label
        if loss_weight is not None:
            loss_seg['loss_weight'] = loss_weight
        if ignore_label is not None or loss_weight is not None:
            warnings.warn('``ignore_label`` and ``loss_weight`` will be '
                          'deprecated soon. Please set ``ignore_index`` and '
                          '``loss_weight`` in ``loss_seg`` instead.')
        self.criterion = MODELS.build(loss_seg)

def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
"""Forward function.
Args:
feats (tuple[Tensor]): Multi scale feature maps.
Returns:
tuple[Tensor]:
- mask_preds (Tensor): Predicted mask logits.
- x (Tensor): Fused feature.
"""
x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
fused_size = tuple(x.shape[-2:])
for i, feat in enumerate(feats):
if i != self.fusion_level:
feat = F.interpolate(
feat, size=fused_size, mode='bilinear', align_corners=True)
                # Avoid in-place ``+=``, which raises a runtime error in
                # PyTorch 1.10.
x = x + self.lateral_convs[i](feat)
for i in range(self.num_convs):
x = self.convs[i](x)
mask_preds = self.conv_logits(x)
x = self.conv_embedding(x)
        return mask_preds, x

def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor:
"""Loss function.
Args:
mask_preds (Tensor): Predicted mask logits.
labels (Tensor): Ground truth.
Returns:
Tensor: Semantic segmentation loss.
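
        Example:
            An illustrative sketch, assuming 1/8-scale predictions and
            full-resolution labels of shape (N, 1, H, W):

            >>> import torch
            >>> head = FusedSemanticHead(num_ins=5, fusion_level=1)
            >>> mask_preds = torch.rand(2, 183, 32, 32)
            >>> labels = torch.randint(0, 183, (2, 1, 256, 256))
            >>> loss = head.loss(mask_preds, labels)  # scalar loss tensor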
"""
labels = F.interpolate(
labels.float(), scale_factor=self.seg_scale_factor, mode='nearest')
labels = labels.squeeze(1).long()
loss_semantic_seg = self.criterion(mask_preds, labels)
return loss_semantic_seg