# Spaces: Runtime error  (appears to be page-capture residue, not part of the module)
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule, normal_init | |
from mmseg.ops import resize | |
class BaseDecodeHead(nn.Module):
    """Base class for decode heads.

    Validates and stores the input-selection settings, and builds the
    optional dropout layer plus the final 1x1 classification convolution
    (``conv_seg``). Subclasses are expected to override :meth:`forward`.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes (keyword-only).
        dropout_ratio (float): Ratio of dropout layer. A value that is not
            greater than 0 disables dropout. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Only
            stored on the instance and reported by ``extra_repr`` here.
            Default: 255
        align_corners (bool): ``align_corners`` argument forwarded to
            ``resize`` by the 'resize_concat' transform. Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super().__init__()
        # Sets self.in_channels, self.in_index and self.input_transform.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        # Per-pixel classifier: 1x1 conv from `channels` to `num_classes`.
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        # `None` (rather than a missing attribute) marks "no dropout" so
        # that cls_seg can test for it safely.
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        return (f'input_transform={self.input_transform}, '
                f'ignore_index={self.ignore_index}, '
                f'align_corners={self.align_corners}')

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input
                features. Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': ``self.in_channels`` becomes the sum of
                    ``in_channels``.
                'multiple_select': ``self.in_channels`` keeps the given
                    sequence.
                None: ``self.in_channels`` is the given int.

        Raises:
            AssertionError: If the arguments are inconsistent.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select'], (
                "input_transform must be 'resize_concat', "
                "'multiple_select' or None")
        self.input_transform = input_transform
        self.in_index = in_index

        if input_transform is None:
            assert isinstance(in_channels, int), (
                'in_channels must be an int when input_transform is None')
            assert isinstance(in_index, int), (
                'in_index must be an int when input_transform is None')
            self.in_channels = in_channels
            return

        assert isinstance(in_channels, (list, tuple)), (
            'in_channels must be a list or tuple when input_transform is set')
        assert isinstance(in_index, (list, tuple)), (
            'in_index must be a list or tuple when input_transform is set')
        assert len(in_channels) == len(in_index), (
            'in_channels and in_index must have the same length')
        if input_transform == 'resize_concat':
            # Feature maps are concatenated along the channel dimension.
            self.in_channels = sum(in_channels)
        else:
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor|list[Tensor]: A single tensor for 'resize_concat' or
                None, a list of the selected tensors for 'multiple_select'.
        """
        if self.input_transform == 'resize_concat':
            selected = [inputs[i] for i in self.in_index]
            # Every selected map is resized to the spatial size of the
            # first selected one before channel-wise concatenation.
            target_size = selected[0].shape[2:]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=target_size,
                    mode='bilinear',
                    align_corners=self.align_corners) for x in selected
            ]
            return torch.cat(upsampled_inputs, dim=1)
        if self.input_transform == 'multiple_select':
            return [inputs[i] for i in self.in_index]
        return inputs[self.in_index]

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel.

        Applies dropout (when enabled) followed by ``conv_seg``.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. With 0 the conv
            stack is replaced by an identity and ``in_channels`` must
            equal ``channels``. Default: 2.
        kernel_size (int): The kernel size for convs in the head.
            Padding is ``kernel_size // 2``. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0, 'num_convs must be non-negative'
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super().__init__(**kwargs)

        # Settings shared by every ConvModule built below.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if num_convs == 0:
            # Nothing maps in_channels -> channels, so they must agree.
            assert self.in_channels == self.channels, (
                'in_channels must equal channels when num_convs == 0')
            self.convs = nn.Identity()
        else:
            # First conv changes the channel count, the rest keep it.
            layers = [
                ConvModule(self.in_channels, self.channels, **conv_kwargs)
            ]
            layers.extend(
                ConvModule(self.channels, self.channels, **conv_kwargs)
                for _ in range(num_convs - 1))
            self.convs = nn.Sequential(*layers)

        if self.concat_input:
            # Fuses the head input concatenated with the conv stack output.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                **conv_kwargs)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Output of ``cls_seg`` applied to the head features.
        """
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        return self.cls_seg(output)
class MultiHeadFCNHead(nn.Module):
    """Multi-head variant of the FCN decode head.

    Builds ``num_head`` independent branches. Each branch has its own conv
    stack, its own optional concat-fusion conv and its own 1x1
    classification conv. One dropout module is shared by all branches.
    :meth:`forward` returns one prediction per branch.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before the classifiers.
        num_classes (int): Number of classes (keyword-only).
        dropout_ratio (float): Ratio of dropout layer. A value that is not
            greater than 0 disables dropout. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.
        ignore_index (int | None): Stored on the instance; not read by
            this class. Default: 255
        align_corners (bool): ``align_corners`` argument forwarded to
            ``resize`` by the 'resize_concat' transform. Default: False.
        num_convs (int): Number of convs in each branch. With 0 the conv
            stack is an identity and ``in_channels`` must equal
            ``channels``. Default: 2.
        kernel_size (int): The kernel size for convs in each branch.
            Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
        num_head (int): Number of branches. Default: 18.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super().__init__()
        assert num_convs >= 0, 'num_convs must be non-negative'
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        # Sets self.in_channels, self.in_index and self.input_transform.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_head = num_head
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        # Always define `dropout`: forward() tests `self.dropout is not
        # None`, and nn.Module raises AttributeError for a missing
        # attribute.
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

        # One 1x1 classifier per branch.
        self.conv_seg_head_list = nn.ModuleList(
            nn.Conv2d(channels, num_classes, kernel_size=1)
            for _ in range(self.num_head))
        self.init_weights()

        if num_convs == 0:
            # Nothing maps in_channels -> channels, so they must agree.
            assert self.in_channels == self.channels, (
                'in_channels must equal channels when num_convs == 0')

        # Settings shared by every ConvModule built below.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        convs_list = []
        conv_cat_list = []
        for _ in range(self.num_head):
            if num_convs == 0:
                convs_list.append(nn.Identity())
            else:
                # First conv changes the channel count, the rest keep it.
                layers = [
                    ConvModule(self.in_channels, self.channels,
                               **conv_kwargs)
                ]
                layers.extend(
                    ConvModule(self.channels, self.channels, **conv_kwargs)
                    for _ in range(num_convs - 1))
                convs_list.append(nn.Sequential(*layers))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        **conv_kwargs))
        self.convs_list = nn.ModuleList(convs_list)
        # Stays empty when concat_input is False.
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            list[Tensor]: One classifier output per branch, in branch
                order (``num_head`` entries).
        """
        x = self._transform_inputs(inputs)
        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output_list.append(self.conv_seg_head_list[head_idx](output))
        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input
                features. Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': ``self.in_channels`` becomes the sum of
                    ``in_channels``.
                'multiple_select': ``self.in_channels`` keeps the given
                    sequence.
                None: ``self.in_channels`` is the given int.

        Raises:
            AssertionError: If the arguments are inconsistent.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select'], (
                "input_transform must be 'resize_concat', "
                "'multiple_select' or None")
        self.input_transform = input_transform
        self.in_index = in_index

        if input_transform is None:
            assert isinstance(in_channels, int), (
                'in_channels must be an int when input_transform is None')
            assert isinstance(in_index, int), (
                'in_index must be an int when input_transform is None')
            self.in_channels = in_channels
            return

        assert isinstance(in_channels, (list, tuple)), (
            'in_channels must be a list or tuple when input_transform is set')
        assert isinstance(in_index, (list, tuple)), (
            'in_index must be a list or tuple when input_transform is set')
        assert len(in_channels) == len(in_index), (
            'in_channels and in_index must have the same length')
        if input_transform == 'resize_concat':
            # Feature maps are concatenated along the channel dimension.
            self.in_channels = sum(in_channels)
        else:
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of every classification layer."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor|list[Tensor]: A single tensor for 'resize_concat' or
                None, a list of the selected tensors for 'multiple_select'.
        """
        if self.input_transform == 'resize_concat':
            selected = [inputs[i] for i in self.in_index]
            # Every selected map is resized to the spatial size of the
            # first selected one before channel-wise concatenation.
            target_size = selected[0].shape[2:]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=target_size,
                    mode='bilinear',
                    align_corners=self.align_corners) for x in selected
            ]
            return torch.cat(upsampled_inputs, dim=1)
        if self.input_transform == 'multiple_select':
            return [inputs[i] for i in self.in_index]
        return inputs[self.in_index]