"""Backbone modules."""
from collections import OrderedDict
import os
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, clean_state_dict, is_main_process
from ..position_encoding import build_position_encoding
from .swin_transformer import build_swin_transformer
class FrozenBatchNorm2d(torch.nn.Module):
"""BatchNorm2d where the batch statistics and the affine parameters are
fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt;
    without it, models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
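        # Statistics and affine terms are registered as buffers, not
        # Parameters: the optimizer never updates them, but they still move
        # with .to()/.cuda() and round-trip through the state dict.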
self.register_buffer('weight', torch.ones(n))
self.register_buffer('bias', torch.zeros(n))
self.register_buffer('running_mean', torch.zeros(n))
self.register_buffer('running_var', torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
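        # torchvision BatchNorm2d checkpoints carry a num_batches_tracked
        # buffer that this frozen variant does not define; drop it so the
        # load does not report it as an unexpected key.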
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d,
self)._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs)
def forward(self, x):
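        # Fold the frozen statistics into a single per-channel scale and
        # bias, so the forward pass is one multiply-add (fuser friendly).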
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
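
# A minimal sanity check (illustrative only, not executed by this module):
# in eval mode, FrozenBatchNorm2d should match nn.BatchNorm2d loaded with
# the same statistics, since both use eps = 1e-5:
#
#   bn = torch.nn.BatchNorm2d(8).eval()
#   fbn = FrozenBatchNorm2d(8)
#   fbn.load_state_dict(bn.state_dict())
#   x = torch.randn(2, 8, 4, 4)
#   assert torch.allclose(bn(x), fbn(x), atol=1e-5)
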
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool,
num_channels: int, return_interm_indices: list):
super().__init__()
for name, parameter in backbone.named_parameters():
            # Always freeze the stem (conv1/bn1); also freeze every residual
            # stage when the backbone is not being trained.
            if not train_backbone or all(
                    'layer{}'.format(i) not in name for i in range(5)):
                parameter.requires_grad_(False)
return_layers = {}
for idx, layer_index in enumerate(return_interm_indices):
return_layers.update({
'layer{}'.format(5 - len(return_interm_indices) + idx):
'{}'.format(layer_index)
})
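            # e.g. return_interm_indices == [1, 2, 3] yields
            # return_layers == {'layer2': '1', 'layer3': '2', 'layer4': '3'}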
self.body = IntermediateLayerGetter(backbone,
return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
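            # Downsample the padding mask (nearest neighbour, the
            # F.interpolate default) to the spatial size of this feature map.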
mask = F.interpolate(m[None].float(),
size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
dilation: bool,
return_interm_indices: list,
batch_norm=FrozenBatchNorm2d,
):
if name in ['resnet18', 'resnet34', 'resnet50', 'resnet101']:
            # NOTE: ImageNet-pretrained weights are deliberately not loaded
            # here (the original code passed pretrained=is_main_process()).
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=False,
norm_layer=batch_norm)
        else:
            raise NotImplementedError(
                'Unknown backbone name: {}'.format(name))
        assert name not in ('resnet18', 'resnet34'), \
            'Only resnet50 and resnet101 are available.'
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
num_channels_all = [256, 512, 1024, 2048]
num_channels = num_channels_all[4 - len(return_interm_indices):]
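        # e.g. return_interm_indices == [1, 2, 3] keeps [512, 1024, 2048]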
super().__init__(backbone, train_backbone, num_channels,
return_interm_indices)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
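        # Pair each backbone feature map with its position encoding, cast
        # to the feature dtype (e.g. for fp16 inference).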
for name, x in xs.items():
out.append(x)
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
"""Useful args:
- backbone: backbone name
- lr_backbone:
- dilation
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
- backbone_freeze_keywords:
- use_checkpoint: for swin only for now
"""
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
if not train_backbone:
raise ValueError('Please set lr_backbone > 0')
    return_interm_indices = args.return_interm_indices  # typically [1, 2, 3]
    assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
    backbone_freeze_keywords = args.backbone_freeze_keywords  # typically None
    use_checkpoint = getattr(args, 'use_checkpoint', False)
if args.backbone in ['resnet50', 'resnet101']:
backbone = Backbone(args.backbone,
train_backbone,
args.dilation,
return_interm_indices,
batch_norm=FrozenBatchNorm2d)
bb_num_channels = backbone.num_channels
elif args.backbone in [
'swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k',
'swin_L_224_22k', 'swin_L_384_22k'
]:
pretrain_img_size = int(args.backbone.split('_')[-2])
backbone = build_swin_transformer(
args.backbone,
pretrain_img_size=pretrain_img_size,
out_indices=tuple(return_interm_indices),
dilation=args.dilation,
use_checkpoint=use_checkpoint)
# freeze some layers
if backbone_freeze_keywords is not None:
for name, parameter in backbone.named_parameters():
for keyword in backbone_freeze_keywords:
if keyword in name:
parameter.requires_grad_(False)
break
        pretrained_dir = os.environ.get('pretrain_model_path')
        assert pretrained_dir is not None, (
            'Set the pretrain_model_path environment variable to the '
            'directory containing the Swin checkpoints.')
PTDICT = {
'swin_T_224_1k': 'swin_tiny_patch4_window7_224.pth',
'swin_B_384_22k': 'swin_base_patch4_window12_384.pth',
'swin_L_384_22k': 'swin_large_patch4_window12_384_22k.pth',
}
        assert args.backbone in PTDICT, (
            'No pretrained checkpoint registered for {}'.format(args.backbone))
        pretrainedpath = os.path.join(pretrained_dir, PTDICT[args.backbone])
checkpoint = torch.load(pretrainedpath, map_location='cpu')['model']
def key_select_function(keyname):
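            # Keep everything except the ImageNet classification head and,
            # when dilation is enabled, the last stage ('layers.3'), whose
            # pretrained weights would no longer match the dilated model.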
if 'head' in keyname:
return False
if args.dilation and 'layers.3' in keyname:
return False
return True
_tmp_st = OrderedDict({
k: v
for k, v in clean_state_dict(checkpoint).items()
if key_select_function(k)
})
_tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
print(str(_tmp_st_output))
        bb_num_channels = backbone.num_features[
            4 - len(return_interm_indices):]
else:
raise NotImplementedError('Unknown backbone {}'.format(args.backbone))
assert len(bb_num_channels) == len(
return_interm_indices
), f'len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}'
model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels
    assert isinstance(bb_num_channels, list), (
        'bb_num_channels is expected to be a list but got {}'.format(
            type(bb_num_channels)))
return model
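
# Example usage (a sketch: the exact config comes from the AiOS configs, and
# build_position_encoding expects its own fields, elided here):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       backbone='resnet50', lr_backbone=1e-5, dilation=False,
#       return_interm_indices=[1, 2, 3], backbone_freeze_keywords=None,
#       use_checkpoint=False)  # plus the position-encoding fields
#   model = build_backbone(args)
#   # model.num_channels == [512, 1024, 2048]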