Spaces:
Runtime error
Runtime error
File size: 6,975 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import ModuleList
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from ..layers import ConvUpsample
from ..utils import interpolate_as
from .base_semantic_head import BaseSemanticHead
@MODELS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
    thing classes will be merged to ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        in_channels (int): Number of channels in the input feature
            map. Default: 256.
        inner_channels (int): Number of channels in inner features.
            Default: 128.
        start_level (int): The start level of the input features
            used in PanopticFPN. Default: 0.
        end_level (int): The end level of the used features, the
            ``end_level``-th layer will not be used. Default: 4.
        conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
            and config conv layer. Default: None.
        norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
            norm layer. Use ``GN`` by default.
        loss_seg (Union[ConfigDict, dict]): The loss of the semantic head.
            Defaults to ``CrossEntropyLoss`` with ``ignore_index=-1``.
        init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
            dict. Default: None.
    """

    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 in_channels: int = 256,
                 inner_channels: int = 128,
                 start_level: int = 0,
                 end_level: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        # Factor used by ``loss`` (as ``self.seg_rescale_factor``, presumably
        # stored by the base class) to downsample the ground-truth maps.
        seg_rescale_factor = 1 / 2**(start_level + 2)
        super().__init__(
            num_classes=num_stuff_classes + 1,
            seg_rescale_factor=seg_rescale_factor,
            loss_seg=loss_seg,
            init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        # Used feature layers are [start_level, end_level)
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        # One branch per used level: level ``i`` gets ``i`` conv layers and
        # ``i`` upsample steps (level 0: one conv, no upsampling). The
        # branch outputs are summed in ``forward``, so they must end up with
        # equal shapes.
        self.conv_upsample_layers = ModuleList()
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        # 1x1 classifier: ``num_stuff_classes`` stuff channels + 1 thing one.
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
        """Merge thing classes to one class.

        In PanopticFPN, the background (stuff) labels will be reset from
        ``0`` to ``self.num_stuff_classes - 1``, and the foreground (thing)
        labels will be merged to the ``self.num_stuff_classes``-th channel.
        Labels outside ``[0, num_things_classes + num_stuff_classes)``
        (e.g. ignore values such as ``-1`` or ``255``) are left unchanged,
        so they can still match the loss's ``ignore_index``.

        Args:
            gt_semantic_seg (Tensor): Semantic label map in the original
                ``things + stuff`` label space.

        Returns:
            Tensor: An int32 label map in the merged label space.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        num_valid = self.num_things_classes + self.num_stuff_classes
        # The lower bound matters: negative values are not classes (the
        # default ignore_index is -1) and must not be relabelled as "thing".
        fg_mask = (gt_semantic_seg >= 0) & (
            gt_semantic_seg < self.num_things_classes)
        bg_mask = (gt_semantic_seg >= self.num_things_classes) & (
            gt_semantic_seg < num_valid)

        # Shift stuff labels down to start at 0; everything else is kept.
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 gt_semantic_seg)
        # Collapse every thing label into the single extra channel.
        new_gt_seg = new_gt_seg.masked_fill(fg_mask, self.num_stuff_classes)
        return new_gt_seg

    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the semantic segmentation loss.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            Dict[str, Tensor]: The loss of semantic head.
        """
        seg_preds = self(x)['seg_preds']
        gt_semantic_segs = torch.stack([
            data_sample.gt_sem_seg.sem_seg
            for data_sample in batch_data_samples
        ])
        if self.seg_rescale_factor != 1.0:
            # Nearest-neighbour resampling keeps labels integral. The
            # squeeze presumably drops a singleton channel dim of each
            # ``sem_seg`` (assumed (1, H, W) -- TODO confirm).
            gt_semantic_segs = F.interpolate(
                gt_semantic_segs.float(),
                scale_factor=self.seg_rescale_factor,
                mode='nearest').squeeze(1)

        # Things classes will be merged to one class in PanopticFPN.
        gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)

        # Align the predictions to the label map's spatial size if needed.
        if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
        # Move the class axis (conv_logits output channels) last so each row
        # of the flattened view holds one pixel's logits.
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_segs.reshape(-1).long())
        return dict(loss_seg=loss_seg)

    def init_weights(self) -> None:
        """Initialize weights.

        Runs the base-class initialization, then re-initializes the final
        classifier with N(0, 0.01) weights and zero bias.
        """
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
        """Forward.

        Args:
            x (Tuple[Tensor]): Multi scale Feature maps.

        Returns:
            dict[str, Tensor]: Semantic segmentation predictions
            (``seg_preds``) and the summed feature map (``seg_feats``).
        """
        # Levels ``start_level .. end_level - 1`` of ``x`` are indexed below,
        # so ``x`` must hold at least ``end_level`` maps; checking only
        # ``num_stages`` is not enough when ``start_level > 0``.
        assert self.end_level <= len(x), (
            f'PanopticFPNHead needs at least end_level={self.end_level} '
            f'feature maps, but got {len(x)}.')
        feats = [
            layer(x[self.start_level + i])
            for i, layer in enumerate(self.conv_upsample_layers)
        ]
        # Element-wise sum of the per-level branches.
        seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(seg_feats)
        return dict(seg_preds=seg_preds, seg_feats=seg_feats)
|