Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union | |
from mmcv.cnn import ConvModule | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from .fcn_mask_head import FCNMaskHead | |
class HTCMaskHead(FCNMaskHead):
    """Mask head for HTC.

    Extends ``FCNMaskHead`` with an optional residual input: when a
    ``res_feat`` tensor is passed to :meth:`forward`, it is first sent through
    a 1x1 ``ConvModule`` and then added to the incoming feature map before the
    conv stack runs.

    Args:
        with_conv_res (bool): Whether add conv layer for ``res_feat``.
            Defaults to True.
        *args: Forwarded unchanged to ``FCNMaskHead.__init__``.
        **kwargs: Forwarded unchanged to ``FCNMaskHead.__init__``.
    """

    def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_conv_res = with_conv_res
        if not with_conv_res:
            # No residual branch requested, so ``conv_res`` is never created.
            return

        # ``conv_out_channels``, ``conv_cfg`` and ``norm_cfg`` are read from
        # the instance; presumably they are set by the parent constructor.
        channels = self.conv_out_channels
        self.conv_res = ConvModule(
            channels,
            channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

    def forward(self,
                x: Tensor,
                res_feat: Optional[Tensor] = None,
                return_logits: bool = True,
                return_feat: bool = True) -> Union[Tensor, List[Tensor]]:
        """Run the mask head, optionally fusing a residual feature.

        Args:
            x (Tensor): Feature map.
            res_feat (Tensor, optional): Feature for residual connection.
                When given, the head must have been built with
                ``with_conv_res=True``. Defaults to None.
            return_logits (bool): Whether return mask logits.
                Defaults to True.
            return_feat (bool): Whether return the feature map produced by
                the conv stack. Defaults to True.

        Returns:
            Union[Tensor, List[Tensor]]: One of three results:

            - ``[logits, feat]`` when both flags are set,
            - ``logits`` alone when only ``return_logits`` is set,
            - ``feat`` alone when only ``return_feat`` is set,

            where ``feat`` is the output of the conv stack (before
            upsampling).

        Raises:
            AssertionError: If both ``return_logits`` and ``return_feat`` are
                false, or if ``res_feat`` is given while ``with_conv_res`` is
                false.
        """
        # At least one of the two outputs has to be requested.
        assert return_logits or return_feat

        if res_feat is not None:
            assert self.with_conv_res
            # Project the residual feature, then fuse it into the input.
            x = x + self.conv_res(res_feat)

        for layer in self.convs:
            x = layer(x)
        # Keep the pre-upsample feature: this is what gets returned as the
        # feature output.
        conv_feat = x

        outputs = []
        if return_logits:
            upsampled = self.upsample(x)
            if self.upsample_method == 'deconv':
                upsampled = self.relu(upsampled)
            outputs.append(self.conv_logits(upsampled))
        if return_feat:
            outputs.append(conv_feat)

        # A single requested output is returned bare rather than in a list.
        if len(outputs) > 1:
            return outputs
        return outputs[0]