Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Tuple | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmengine.config import ConfigDict | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from mmdet.utils import MultiConfig, OptConfigType | |
class FusedSemanticHead(BaseModule): | |
r"""Multi-level fused semantic segmentation head. | |
.. code-block:: none | |
in_1 -> 1x1 conv --- | |
| | |
in_2 -> 1x1 conv -- | | |
|| | |
in_3 -> 1x1 conv - || | |
||| /-> 1x1 conv (mask prediction) | |
in_4 -> 1x1 conv -----> 3x3 convs (*4) | |
| \-> 1x1 conv (feature) | |
in_5 -> 1x1 conv --- | |
""" # noqa: W605 | |
def __init__( | |
self, | |
num_ins: int, | |
fusion_level: int, | |
seg_scale_factor=1 / 8, | |
num_convs: int = 4, | |
in_channels: int = 256, | |
conv_out_channels: int = 256, | |
num_classes: int = 183, | |
conv_cfg: OptConfigType = None, | |
norm_cfg: OptConfigType = None, | |
ignore_label: int = None, | |
loss_weight: float = None, | |
loss_seg: ConfigDict = dict( | |
type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2), | |
init_cfg: MultiConfig = dict( | |
type='Kaiming', override=dict(name='conv_logits')) | |
) -> None: | |
super().__init__(init_cfg=init_cfg) | |
self.num_ins = num_ins | |
self.fusion_level = fusion_level | |
self.seg_scale_factor = seg_scale_factor | |
self.num_convs = num_convs | |
self.in_channels = in_channels | |
self.conv_out_channels = conv_out_channels | |
self.num_classes = num_classes | |
self.conv_cfg = conv_cfg | |
self.norm_cfg = norm_cfg | |
self.fp16_enabled = False | |
self.lateral_convs = nn.ModuleList() | |
for i in range(self.num_ins): | |
self.lateral_convs.append( | |
ConvModule( | |
self.in_channels, | |
self.in_channels, | |
1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg, | |
inplace=False)) | |
self.convs = nn.ModuleList() | |
for i in range(self.num_convs): | |
in_channels = self.in_channels if i == 0 else conv_out_channels | |
self.convs.append( | |
ConvModule( | |
in_channels, | |
conv_out_channels, | |
3, | |
padding=1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg)) | |
self.conv_embedding = ConvModule( | |
conv_out_channels, | |
conv_out_channels, | |
1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg) | |
self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1) | |
if ignore_label: | |
loss_seg['ignore_index'] = ignore_label | |
if loss_weight: | |
loss_seg['loss_weight'] = loss_weight | |
if ignore_label or loss_weight: | |
warnings.warn('``ignore_label`` and ``loss_weight`` would be ' | |
'deprecated soon. Please set ``ingore_index`` and ' | |
'``loss_weight`` in ``loss_seg`` instead.') | |
self.criterion = MODELS.build(loss_seg) | |
def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]: | |
"""Forward function. | |
Args: | |
feats (tuple[Tensor]): Multi scale feature maps. | |
Returns: | |
tuple[Tensor]: | |
- mask_preds (Tensor): Predicted mask logits. | |
- x (Tensor): Fused feature. | |
""" | |
x = self.lateral_convs[self.fusion_level](feats[self.fusion_level]) | |
fused_size = tuple(x.shape[-2:]) | |
for i, feat in enumerate(feats): | |
if i != self.fusion_level: | |
feat = F.interpolate( | |
feat, size=fused_size, mode='bilinear', align_corners=True) | |
# fix runtime error of "+=" inplace operation in PyTorch 1.10 | |
x = x + self.lateral_convs[i](feat) | |
for i in range(self.num_convs): | |
x = self.convs[i](x) | |
mask_preds = self.conv_logits(x) | |
x = self.conv_embedding(x) | |
return mask_preds, x | |
def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor: | |
"""Loss function. | |
Args: | |
mask_preds (Tensor): Predicted mask logits. | |
labels (Tensor): Ground truth. | |
Returns: | |
Tensor: Semantic segmentation loss. | |
""" | |
labels = F.interpolate( | |
labels.float(), scale_factor=self.seg_scale_factor, mode='nearest') | |
labels = labels.squeeze(1).long() | |
loss_semantic_seg = self.criterion(mask_preds, labels) | |
return loss_semantic_seg | |