ai-photo-gallery / mmdet /models /seg_heads /panoptic_fpn_head.py
KyanChen's picture
init
f549064
raw
history blame
No virus
6.98 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import ModuleList
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from ..layers import ConvUpsample
from ..utils import interpolate_as
from .base_semantic_head import BaseSemanticHead
@MODELS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
    thing classes will be merged to ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        in_channels (int): Number of channels in the input feature
            map. Default: 256.
        inner_channels (int): Number of channels in inner features.
            Default: 128.
        start_level (int): The start level of the input features
            used in PanopticFPN. Default: 0.
        end_level (int): The end level of the used features, the
            ``end_level``-th layer will not be used. Default: 4.
        conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
            and config conv layer.
        norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
            norm layer. Use ``GN`` by default.
        init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
            dict.
        loss_seg (Union[ConfigDict, dict]): the loss of the semantic head.
    """

    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 in_channels: int = 256,
                 inner_channels: int = 128,
                 start_level: int = 0,
                 end_level: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        # Rescale factor applied to the ground truth in ``loss``; it is
        # 1/4 for ``start_level=0`` and halves for each further level.
        seg_rescale_factor = 1 / 2**(start_level + 2)
        super().__init__(
            # All stuff classes plus a single merged "thing" class.
            num_classes=num_stuff_classes + 1,
            seg_rescale_factor=seg_rescale_factor,
            loss_seg=loss_seg,
            init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        # Used feature layers are [start_level, end_level)
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        # One ConvUpsample branch per used feature level. Level ``i`` gets
        # ``i`` layers and ``i`` upsampling steps, except level 0 which
        # gets a single layer and no upsampling.
        self.conv_upsample_layers = ModuleList()
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        # 1x1 convolution mapping the fused features to class logits.
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
        """Merge thing classes to one class.

        In PanopticFPN, the background labels will be reset from `0` to
        `self.num_stuff_classes-1`, the foreground labels will be merged to
        `self.num_stuff_classes`-th channel.

        Args:
            gt_semantic_seg (Tensor): Ground-truth label map whose values
                are thing labels in ``[0, num_things_classes)`` and stuff
                labels in ``[num_things_classes,
                num_things_classes + num_stuff_classes)``.

        Returns:
            Tensor: A new ``int32`` tensor of the same shape with the
            remapped labels. Labels at or above ``num_things_classes +
            num_stuff_classes`` are left unchanged.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        # NOTE(review): every label below ``num_things_classes`` counts as
        # foreground here, which includes negative labels -- confirm this
        # is intended when a negative ``ignore_index`` is configured.
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) & (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        # ``torch.where`` already returns a new tensor, so no explicit
        # clone of the input is needed before remapping.
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 gt_semantic_seg)
        # Inside ``fg_mask`` the expression below equals
        # ``num_stuff_classes`` everywhere, i.e. the merged thing label.
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg

    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the semantic segmentation loss of the head.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            Dict[str, Tensor]: The loss of semantic head.
        """
        seg_preds = self(x)['seg_preds']
        gt_semantic_segs = [
            data_sample.gt_sem_seg.sem_seg
            for data_sample in batch_data_samples
        ]

        # presumably each ``sem_seg`` is (1, H, W), so the stack is
        # (N, 1, H, W) -- TODO confirm against the data pipeline.
        gt_semantic_segs = torch.stack(gt_semantic_segs)
        if self.seg_rescale_factor != 1.0:
            # Nearest-neighbour resizing keeps the values valid labels.
            gt_semantic_segs = F.interpolate(
                gt_semantic_segs.float(),
                scale_factor=self.seg_rescale_factor,
                mode='nearest').squeeze(1)

        # Things classes will be merged to one class in PanopticFPN.
        gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)

        # Bring the predictions to the spatial size of the ground truth
        # when the two differ.
        if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
        # Channels last, so that each pixel becomes one row of logits.
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_segs.reshape(-1).long())

        return dict(loss_seg=loss_seg)

    def init_weights(self) -> None:
        """Initialize weights.

        Runs the parent initialization, then draws the weights of
        ``conv_logits`` from a normal distribution (mean 0, std 0.01) and
        zeroes its bias.
        """
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
        """Forward.

        Args:
            x (Tuple[Tensor]): Multi scale Feature maps.

        Returns:
            dict[str, Tensor]: semantic segmentation predictions and
            feature maps, stored under the keys ``seg_preds`` and
            ``seg_feats``.

        Raises:
            AssertionError: If ``x`` has fewer feature maps than the
                configured levels ``[start_level, end_level)`` require.
        """
        # the number of subnets must be not more than
        # the length of features.
        assert self.num_stages <= len(x)
        # The loop below reads x[start_level] ... x[end_level - 1], so the
        # count check above is not sufficient when ``start_level > 0``.
        assert self.end_level <= len(x), (
            f'end_level ({self.end_level}) must not exceed the number of '
            f'input feature maps ({len(x)})')

        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        # Fuse the per-level branch outputs by element-wise summation.
        seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(seg_feats)
        out = dict(seg_preds=seg_preds, seg_feats=seg_feats)
        return out