# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, ConvModule
from mmengine.model import BaseModule, ModuleList, caffe2_xavier_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from .positional_encoding import SinePositionalEncoding
from .transformer import DetrTransformerEncoder


@MODELS.register_module()
class PixelDecoder(BaseModule):
    """Pixel decoder with an FPN-like structure.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for the output.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
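
    Example:
        A minimal usage sketch; the channel counts and feature shapes
        below are illustrative assumptions, not required settings.

        >>> import torch
        >>> decoder = PixelDecoder(
        ...     in_channels=[256, 512, 1024, 2048],
        ...     feat_channels=256,
        ...     out_channels=256)
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([256, 512, 1024, 2048], [64, 32, 16, 8])]
        >>> mask_feature, memory = decoder(feats, batch_img_metas=[])
        >>> mask_feature.shape
        torch.Size([1, 256, 64, 64])
        >>> memory.shape  # the last backbone feature, returned unchanged
        torch.Size([1, 2048, 8, 8])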
""" | |

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_inputs = len(in_channels)
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        self.use_bias = norm_cfg is None
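        # Build one 1x1 lateral conv and one 3x3 output conv per input
        # level except the last, which is handled by `last_feat_conv`.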
        for i in range(0, self.num_inputs - 1):
            lateral_conv = ConvModule(
                in_channels[i],
                feat_channels,
                kernel_size=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None)
            output_conv = ConvModule(
                feat_channels,
                feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.lateral_convs.append(lateral_conv)
            self.output_convs.append(output_conv)

        self.last_feat_conv = ConvModule(
            in_channels[-1],
            feat_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.mask_feature = Conv2d(
            feat_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def init_weights(self) -> None:
        """Initialize weights."""
        # All num_inputs - 1 lateral/output convs are initialized; the
        # original upper bound of num_inputs - 2 skipped the last pair.
        for i in range(0, self.num_inputs - 1):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        # Initialize the inner conv of the ConvModule (passing the
        # ConvModule itself would silently be a no-op).
        caffe2_xavier_init(self.last_feat_conv.conv, bias=0)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information.
                Passed in for creating a more accurate padding mask. Not
                used here.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Output of the last stage of the
                  backbone. Shape (batch_size, c, h, w).
        """
        y = self.last_feat_conv(feats[-1])
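        # Top-down pathway: upsample the coarser map and fuse it with the
        # lateral projection of the next finer level, FPN-style.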
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        memory = feats[-1]
        return mask_feature, memory


@MODELS.register_module()
class TransformerEncoderPixelDecoder(PixelDecoder):
    """Pixel decoder with a transformer encoder inside.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for the output.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`ConfigDict` or dict): Config for the transformer
            encoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer encoder positional encoding. Defaults to
            dict(num_feats=128, normalize=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
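
    Example:
        A minimal usage sketch; the single-layer encoder config and the
        feature shapes below are illustrative assumptions, not
        recommended settings.

        >>> import torch
        >>> decoder = TransformerEncoderPixelDecoder(
        ...     in_channels=[256, 512, 1024, 2048],
        ...     feat_channels=256,
        ...     out_channels=256,
        ...     encoder=dict(
        ...         num_layers=1,
        ...         layer_cfg=dict(
        ...             self_attn_cfg=dict(embed_dims=256, num_heads=8),
        ...             ffn_cfg=dict(embed_dims=256,
        ...                          feedforward_channels=1024))))
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([256, 512, 1024, 2048], [64, 32, 16, 8])]
        >>> img_metas = [dict(batch_input_shape=(256, 256),
        ...                   img_shape=(256, 256))]
        >>> mask_feature, memory = decoder(feats, img_metas)
        >>> mask_feature.shape
        torch.Size([1, 256, 64, 64])
        >>> memory.shape  # encoder output at the last feature's scale
        torch.Size([1, 256, 8, 8])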
""" | |

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 encoder: ConfigType = None,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)
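        # The transformer encoder takes over the role of `last_feat_conv`
        # in the parent class, so the plain conv is dropped.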
        self.last_feat_conv = None

        self.encoder = DetrTransformerEncoder(**encoder)
        self.encoder_embed_dims = self.encoder.embed_dims
        assert self.encoder_embed_dims == feat_channels, \
            'embed_dims({}) of the transformer encoder must equal ' \
            'feat_channels({})'.format(self.encoder_embed_dims,
                                       feat_channels)
        self.positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        self.encoder_in_proj = Conv2d(
            in_channels[-1], feat_channels, kernel_size=1)
        self.encoder_out_proj = ConvModule(
            feat_channels,
            feat_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def init_weights(self) -> None:
        """Initialize weights."""
        # All num_inputs - 1 lateral/output convs are initialized; the
        # original upper bound of num_inputs - 2 skipped the last pair.
        for i in range(0, self.num_inputs - 1):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        caffe2_xavier_init(self.encoder_in_proj, bias=0)
        caffe2_xavier_init(self.encoder_out_proj.conv, bias=0)

        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information.
                Passed in for creating a more accurate padding mask.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Shape (batch_size, c, h, w).
        """
        feat_last = feats[-1]
        bs, c, h, w = feat_last.shape
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
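        # Build the padding mask: 0 inside each valid image region, 1 in
        # padded areas; True entries are ignored by the encoder attention.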
        padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w),
                                          dtype=torch.float32)
        for i in range(bs):
            img_h, img_w = batch_img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1),
            size=feat_last.shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)

        pos_embed = self.positional_encoding(padding_mask)
        feat_last = self.encoder_in_proj(feat_last)
        # (batch_size, c, h, w) -> (batch_size, num_queries, c)
        feat_last = feat_last.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # (batch_size, h, w) -> (batch_size, h*w)
        padding_mask = padding_mask.flatten(1)
        memory = self.encoder(
            query=feat_last,
            query_pos=pos_embed,
            key_padding_mask=padding_mask)
        # (batch_size, num_queries, c) -> (batch_size, c, h, w)
        memory = memory.permute(0, 2, 1).view(bs, self.encoder_embed_dims, h,
                                              w)
        y = self.encoder_out_proj(memory)
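        # Same FPN-style top-down fusion as in PixelDecoder.forward.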
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        return mask_feature, memory