import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init

from mmseg.ops import resize


def _resolve_in_channels(in_channels, in_index, input_transform):
    """Validate the input specification and return the effective channels.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        in_index (int|Sequence[int]): Input feature index.
        input_transform (str|None): 'resize_concat', 'multiple_select' or
            None.

    Returns:
        int|Sequence[int]: ``sum(in_channels)`` for 'resize_concat',
        otherwise ``in_channels`` unchanged.

    Raises:
        AssertionError: If ``input_transform`` is not a supported option or
            the types/lengths of ``in_channels`` and ``in_index`` do not
            match the chosen transform.
    """
    if input_transform is None:
        # A single feature map is selected, so both must be plain ints.
        assert isinstance(in_channels, int)
        assert isinstance(in_index, int)
        return in_channels

    assert input_transform in ['resize_concat', 'multiple_select']
    assert isinstance(in_channels, (list, tuple))
    assert isinstance(in_index, (list, tuple))
    assert len(in_channels) == len(in_index)
    if input_transform == 'resize_concat':
        # Feature maps are concatenated along the channel dimension.
        return sum(in_channels)
    return in_channels


def _select_inputs(inputs, in_index, input_transform, align_corners):
    """Pick (and optionally resize and concatenate) the input features.

    Args:
        inputs (list[Tensor]): List of multi-level img features.
        in_index (int|Sequence[int]): Index/indices into ``inputs``.
        input_transform (str|None): 'resize_concat', 'multiple_select' or
            None.
        align_corners (bool): Forwarded to ``resize`` for 'resize_concat'.

    Returns:
        Tensor|list[Tensor]: A single tensor for 'resize_concat' and None,
        a list of the selected tensors for 'multiple_select'.
    """
    if input_transform == 'resize_concat':
        selected = [inputs[i] for i in in_index]
        # Every selected map is resized to the spatial size of the first one.
        target_size = selected[0].shape[2:]
        resized = [
            resize(
                input=x,
                size=target_size,
                mode='bilinear',
                align_corners=align_corners) for x in selected
        ]
        return torch.cat(resized, dim=1)
    if input_transform == 'multiple_select':
        return [inputs[i] for i in in_index]
    return inputs[in_index]


def _build_conv_stack(in_channels, channels, num_convs, kernel_size,
                      conv_cfg, norm_cfg, act_cfg):
    """Build the ``num_convs`` stacked ConvModules of an FCN head.

    Args:
        in_channels (int): Input channels of the first conv.
        channels (int): Output channels of every conv.
        num_convs (int): Number of convs. 0 yields ``nn.Identity()``.
        kernel_size (int): Kernel size; padding is ``kernel_size // 2``.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.

    Returns:
        nn.Module: ``nn.Identity()`` when ``num_convs == 0``, otherwise an
        ``nn.Sequential`` of ``num_convs`` ConvModules.
    """
    if num_convs == 0:
        # Nothing to build; avoids allocating a ConvModule that is discarded.
        return nn.Identity()
    layers = []
    for conv_idx in range(num_convs):
        layers.append(
            ConvModule(
                in_channels if conv_idx == 0 else channels,
                channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))
    return nn.Sequential(*layers)


class BaseDecodeHead(nn.Module):
    """Base class for decode heads.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes. Keyword-only.
        dropout_ratio (float): Ratio of dropout layer. A value <= 0 disables
            dropout. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        ignore_index (int | None): The label index to be ignored. It is only
            stored on the instance here. Default: 255
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        # Sets self.input_transform, self.in_index and self.in_channels.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        self.dropout = (
            nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None)

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input
                features. Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        self.in_channels = _resolve_in_channels(in_channels, in_index,
                                                input_transform)

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor|list[Tensor]: The transformed inputs (a list only for
            'multiple_select').
        """
        return _select_inputs(inputs, self.in_index, self.input_transform,
                              self.align_corners)

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is an implementation of
    `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. When 0, the input is
            passed through unchanged and ``in_channels`` must equal
            ``channels``. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
        **kwargs: Forwarded to :class:`BaseDecodeHead`.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        self.convs = _build_conv_stack(self.in_channels, self.channels,
                                       num_convs, kernel_size, self.conv_cfg,
                                       self.norm_cfg, self.act_cfg)
        if self.concat_input:
            # Fuses the raw input with the conv output before classification.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output


class MultiHeadFCNHead(nn.Module):
    """FCN head with several independent conv/classification branches.

    Each of the ``num_head`` branches has the structure of an FCN head
    (`FCNNet <https://arxiv.org/abs/1411.4038>`_): its own conv stack,
    optional concat conv and 1x1 classification conv. All branches receive
    the same transformed input and share a single dropout layer.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before the classification
            convs.
        num_classes (int): Number of classes. Keyword-only.
        dropout_ratio (float): Ratio of dropout layer. A value <= 0 disables
            dropout. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None. Default: None.
        ignore_index (int | None): The label index to be ignored. It is only
            stored on the instance here. Default: 255
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        num_convs (int): Number of convs in each branch. When 0, the input is
            passed through unchanged and ``in_channels`` must equal
            ``channels``. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
        num_head (int): Number of branches. Default: 18.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super(MultiHeadFCNHead, self).__init__()
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        # Sets self.input_transform, self.in_index and self.in_channels.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_head = num_head
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        # Always define the attribute: forward() reads self.dropout
        # unconditionally, so leaving it unset for dropout_ratio <= 0 would
        # raise AttributeError on the first forward pass.
        self.dropout = (
            nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None)

        self.conv_seg_head_list = nn.ModuleList(
            nn.Conv2d(channels, num_classes, kernel_size=1)
            for _ in range(self.num_head))
        self.init_weights()

        if num_convs == 0:
            assert self.in_channels == self.channels

        convs_list = []
        conv_cat_list = []
        for _ in range(self.num_head):
            convs_list.append(
                _build_conv_stack(self.in_channels, self.channels, num_convs,
                                  kernel_size, self.conv_cfg, self.norm_cfg,
                                  self.act_cfg))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
        self.convs_list = nn.ModuleList(convs_list)
        # Stays empty when concat_input is False.
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            list[Tensor]: One classification output per head, in head order.
        """
        x = self._transform_inputs(inputs)
        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output = self.conv_seg_head_list[head_idx](output)
            output_list.append(output)
        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input
                features. Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        self.in_channels = _resolve_in_channels(in_channels, in_index,
                                                input_transform)

    def init_weights(self):
        """Initialize weights of the classification layer of every head."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor|list[Tensor]: The transformed inputs (a list only for
            'multiple_select').
        """
        return _select_inputs(inputs, self.in_index, self.input_transform,
                              self.align_corners)