# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
build_norm_layer)
from mmengine.utils import digit_version
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
class Residual(nn.Module):
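    """Wrap a module with a residual (skip) connection: ``fn(x) + x``."""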
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
@MODELS.register_module()
class ConvMixer(BaseBackbone):
"""ConvMixer. .
A PyTorch implementation of : `Patches Are All You Need?
<https://arxiv.org/pdf/2201.09792.pdf>`_
Modified from the `official repo
<https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
and `timm
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.
Args:
        arch (str | dict): The model's architecture. If a string, it should be
            one of the architectures in ``ConvMixer.arch_settings``. If a dict,
            it should include the following keys:
- embed_dims (int): The dimensions of patch embedding.
- depth (int): Number of repetitions of ConvMixer Layer.
- patch_size (int): The patch size.
- kernel_size (int): The kernel size of depthwise conv layers.
Defaults to '768/32'.
in_channels (int): Number of input image channels. Defaults to 3.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation after each convolution.
Defaults to ``dict(type='GELU')``.
out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
Defaults to 0, which means not freezing any parameters.
init_cfg (dict, optional): Initialization config dict.
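
    Example:
        A minimal usage sketch (the 224x224 input and the printed shape are
        illustrative; with the default ``out_indices=-1`` only the last stage
        is returned):

        >>> import torch
        >>> from mmcls.models import ConvMixer
        >>> model = ConvMixer(arch='768/32')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.shape)
        torch.Size([1, 768, 32, 32])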
"""
arch_settings = {
'768/32': {
'embed_dims': 768,
'depth': 32,
'patch_size': 7,
'kernel_size': 7
},
'1024/20': {
'embed_dims': 1024,
'depth': 20,
'patch_size': 14,
'kernel_size': 9
},
'1536/20': {
'embed_dims': 1536,
'depth': 20,
'patch_size': 7,
'kernel_size': 9
},
}
def __init__(self,
arch='768/32',
in_channels=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
out_indices=-1,
frozen_stages=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
essential_keys = {
'embed_dims', 'depth', 'patch_size', 'kernel_size'
}
            assert essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.embed_dims = arch['embed_dims']
self.depth = arch['depth']
self.patch_size = arch['patch_size']
self.kernel_size = arch['kernel_size']
self.act = build_activation_layer(act_cfg)
# check out indices and frozen stages
if isinstance(out_indices, int):
out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        # Copy to a list so negative indices can be rewritten in place below.
        out_indices = list(out_indices)
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.depth + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# Set stem layers
self.stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.embed_dims,
kernel_size=self.patch_size,
stride=self.patch_size), self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])
        # nn.Conv2d only supports string padding ('same') from torch 1.9;
        # fall back to mmcv's Conv2dAdaptivePadding on older versions.
convfunc = nn.Conv2d
if digit_version(torch.__version__) < digit_version('1.9.0'):
convfunc = Conv2dAdaptivePadding
        # Each ConvMixer layer: a residual depthwise conv (kernel_size x
        # kernel_size, 'same' padding), then a 1x1 pointwise conv, each
        # followed by activation and norm.
self.stages = nn.Sequential(*[
nn.Sequential(
Residual(
nn.Sequential(
convfunc(
self.embed_dims,
self.embed_dims,
self.kernel_size,
groups=self.embed_dims,
padding='same'), self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])),
nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])
for _ in range(self.depth)
])
self._freeze_stages()
def forward(self, x):
x = self.stem(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ConvMixer, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
for i in range(self.frozen_stages):
stage = self.stages[i]
stage.eval()
for param in stage.parameters():
param.requires_grad = False