Spaces:
Runtime error
Runtime error
File size: 6,187 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
build_norm_layer)
from mmengine.utils import digit_version
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
class Residual(nn.Module):
    """Add an identity shortcut around an arbitrary module.

    Computes ``fn(x) + x``, i.e. a residual connection as used in the
    ConvMixer layer.
    """

    def __init__(self, fn):
        super().__init__()
        # The wrapped transformation; registered as a submodule.
        self.fn = fn

    def forward(self, x):
        residual = x
        out = self.fn(x)
        return out + residual
@MODELS.register_module()
class ConvMixer(BaseBackbone):
    """ConvMixer backbone.

    A PyTorch implementation of: `Patches Are All You Need?
    <https://arxiv.org/pdf/2201.09792.pdf>`_

    Modified from the `official repo
    <https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvMixer.arch_settings``. And if dict,
            it should include the following four keys:

            - embed_dims (int): The dimensions of patch embedding.
            - depth (int): Number of repetitions of ConvMixer Layer.
            - patch_size (int): The patch size.
            - kernel_size (int): The kernel size of depthwise conv layers.

            Defaults to '768/32'.
        in_channels (int): Number of input image channels. Defaults to 3.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation after each convolution.
            Defaults to ``dict(type='GELU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): Initialization config dict.
    """

    arch_settings = {
        '768/32': {
            'embed_dims': 768,
            'depth': 32,
            'patch_size': 7,
            'kernel_size': 7
        },
        '1024/20': {
            'embed_dims': 1024,
            'depth': 20,
            'patch_size': 14,
            'kernel_size': 9
        },
        '1536/20': {
            'embed_dims': 1536,
            'depth': 20,
            'patch_size': 7,
            'kernel_size': 9
        },
    }

    def __init__(self,
                 arch='768/32',
                 in_channels=3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 out_indices=-1,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {
                'embed_dims', 'depth', 'patch_size', 'kernel_size'
            }
            assert essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
        else:
            # Fail early with a clear message instead of an opaque
            # TypeError on ``arch['embed_dims']`` below.
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}.')

        self.embed_dims = arch['embed_dims']
        self.depth = arch['depth']
        self.patch_size = arch['patch_size']
        self.kernel_size = arch['kernel_size']
        self.act = build_activation_layer(act_cfg)

        # Check out indices and frozen stages.
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Copy into a list: the caller may pass an immutable tuple, and
        # negative indices are normalized in place below.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depth + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Stem: non-overlapping patch embedding implemented as a conv whose
        # kernel size equals its stride.
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.embed_dims,
                kernel_size=self.patch_size,
                stride=self.patch_size), self.act,
            build_norm_layer(norm_cfg, self.embed_dims)[1])

        # ``padding='same'`` in nn.Conv2d is only supported since torch
        # 1.9.0; fall back to mmcv's adaptive-padding conv before that.
        convfunc = nn.Conv2d
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            convfunc = Conv2dAdaptivePadding

        # Repetitions of ConvMixer Layer: a residual depthwise conv
        # (spatial mixing) followed by a 1x1 conv (channel mixing).
        self.stages = nn.Sequential(*[
            nn.Sequential(
                Residual(
                    nn.Sequential(
                        convfunc(
                            self.embed_dims,
                            self.embed_dims,
                            self.kernel_size,
                            groups=self.embed_dims,
                            padding='same'), self.act,
                        build_norm_layer(norm_cfg, self.embed_dims)[1])),
                nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
                self.act,
                build_norm_layer(norm_cfg, self.embed_dims)[1])
            for _ in range(self.depth)
        ])

        self._freeze_stages()

    def forward(self, x):
        """Forward computation.

        Args:
            x (torch.Tensor): Input images.

        Returns:
            tuple[torch.Tensor]: Feature maps of the stages listed in
            ``self.out_indices``.
        """
        x = self.stem(x)
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, re-freezing the frozen stages."""
        super(ConvMixer, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode and stop
        their gradients."""
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
|