KyanChen's picture
init
f549064
raw
history blame contribute delete
No virus
2.19 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
from mmcv.cnn import ConvModule
from torch import Tensor
from mmdet.registry import MODELS
from .fcn_mask_head import FCNMaskHead
@MODELS.register_module()
class HTCMaskHead(FCNMaskHead):
    """Mask head for HTC.

    Extends :class:`FCNMaskHead` with an optional 1x1 ``ConvModule``
    (``conv_res``). It transforms an incoming residual feature before that
    feature is added to the input feature map. :meth:`forward` can return
    the mask logits, the intermediate feature map (``res_feat``), or both.

    Args:
        with_conv_res (bool): Whether to add a conv layer for ``res_feat``.
            Defaults to True.
        *args, **kwargs: Forwarded unchanged to :class:`FCNMaskHead`.
    """

    def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_conv_res = with_conv_res
        if self.with_conv_res:
            # 1x1 conv, conv_out_channels -> conv_out_channels. The
            # transformed residual keeps the channel count of the conv-stack
            # output.
            self.conv_res = ConvModule(
                self.conv_out_channels,
                self.conv_out_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def forward(self,
                x: Tensor,
                res_feat: Optional[Tensor] = None,
                return_logits: bool = True,
                return_feat: bool = True) -> Union[Tensor, List[Tensor]]:
        """Run the mask head, optionally fusing a residual feature first.

        Args:
            x (Tensor): Feature map.
            res_feat (Tensor, optional): Feature for residual connection.
                If given, it is passed through ``conv_res`` and added to
                ``x`` before the conv stack. This requires
                ``with_conv_res=True``. Defaults to None.
            return_logits (bool): Whether to return mask logits.
                Defaults to True.
            return_feat (bool): Whether to return the feature map produced
                by the conv stack (before upsampling). Defaults to True.

        Returns:
            Union[Tensor, List[Tensor]]: One of three results:

            - ``mask_preds`` if only ``return_logits`` is True;
            - ``res_feat`` if only ``return_feat`` is True;
            - ``[mask_preds, res_feat]`` if both are True.
        """
        assert return_logits or return_feat, (
            'At least one of `return_logits` and `return_feat` must be True.')
        if res_feat is not None:
            assert self.with_conv_res, (
                '`res_feat` was given but this head was built with '
                '`with_conv_res=False`, so it has no `conv_res` layer.')
            res_feat = self.conv_res(res_feat)
            # NOTE(review): assumes x already has conv_out_channels channels
            # (i.e. matches conv_res output) -- confirm against config.
            x = x + res_feat
        for conv in self.convs:
            x = conv(x)
        # Output of the conv stack, captured before upsampling; this is the
        # feature returned when ``return_feat`` is True.
        res_feat = x
        outs = []
        if return_logits:
            # NOTE(review): FCNMaskHead presumably sets ``self.upsample`` to
            # None when ``upsample_method`` is None, and its own forward
            # guards this case. Guard here too instead of calling None.
            if self.upsample is not None:
                x = self.upsample(x)
                if self.upsample_method == 'deconv':
                    x = self.relu(x)
            mask_preds = self.conv_logits(x)
            outs.append(mask_preds)
        if return_feat:
            outs.append(res_feat)
        return outs if len(outs) > 1 else outs[0]