# Copyright (c) OpenMMLab. All rights reserved.
import tempfile

import pytest
import torch
from torch import nn

import mmcv
from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
from mmcv.utils import Registry, build_from_cfg

COMPONENTS = Registry('component')
FOOMODELS = Registry('model')
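# The classes below are registered into these registries so that
# `build_from_cfg` can construct them from a config dict. For example
# (hypothetical usage, not executed at import time):
#   conv = build_from_cfg(dict(type='FooConv1d'), COMPONENTS)
#   assert isinstance(conv, FooConv1d)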

@COMPONENTS.register_module()
class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)


@COMPONENTS.register_module()
class FooConv2d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)


@COMPONENTS.register_module()
class FooLinear(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        if linear is not None:
            self.linear = build_from_cfg(linear, COMPONENTS)
        if conv1d is not None:
            self.conv1d = build_from_cfg(conv1d, COMPONENTS)

    def forward(self, x):
        x = self.linear(x)
        return self.conv1d(x)
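
# FooLinearConv1d builds its submodules from COMPONENTS configs, so whether
# they are initialized by a parent's init_cfg or by their own depends on
# which init_cfg reaches them (exercised in the tests below).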

@FOOMODELS.register_module()
class FooModel(BaseModule):

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        if component1 is not None:
            self.component1 = build_from_cfg(component1, COMPONENTS)
        if component2 is not None:
            self.component2 = build_from_cfg(component2, COMPONENTS)
        if component3 is not None:
            self.component3 = build_from_cfg(component3, COMPONENTS)
        if component4 is not None:
            self.component4 = build_from_cfg(component4, COMPONENTS)

        # `self.reg` is a plain nn.Linear rather than a BaseModule, so it is
        # initialized by matching `layer='Linear'` rules or by an
        # `override` entry in init_cfg.
        self.reg = nn.Linear(3, 4)
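
# BaseModule.init_weights() applies the module's own init_cfg and then
# recurses into child BaseModules; a child with its own init_cfg keeps it
# in preference to the parent's. The tests below exercise this priority.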

def test_initialization_info_logger():
    # 'override' in init_cfg has higher priority than the layer-wide rule
    import os

    import torch.nn as nn

    from mmcv.utils.logging import get_logger

    class OverloadInitConv(nn.Conv2d, BaseModule):

        def init_weights(self):
            for p in self.parameters():
                with torch.no_grad():
                    p.fill_(1)

    class CheckLoggerModel(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv2 = OverloadInitConv(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
        dict(type='Constant', layer='Linear', val=0., bias=1.)
    ]
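    # With this config, Conv2d layers get Normal(std=0.01); `conv3` is
    # handled by the `override` entry instead, whose `bias_prob` asks the
    # initializer to derive the bias from a probability (in mmcv this goes
    # through `bias_init_with_prob`, an inverse-sigmoid of 0.01); Linear
    # layers get Constant(val=0., bias=1.).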
    model = CheckLoggerModel(init_cfg=init_cfg)
    train_log = '20210720_132454.log'
    workdir = tempfile.mkdtemp()
    log_file = os.path.join(workdir, train_log)
    # create a logger
    get_logger('init_logger', log_file=log_file)
    assert not hasattr(model, '_params_init_info')
    model.init_weights()
    # assert `_params_init_info` would be deleted after `init_weights`
    assert not hasattr(model, '_params_init_info')
    # assert initialization information has been dumped
    assert os.path.exists(log_file)

    lines = mmcv.list_from_file(log_file)
    # check initialization information is right
    for i, line in enumerate(lines):
        if 'conv1.weight' in line:
            assert 'NormalInit' in lines[i + 1]
        if 'conv2.weight' in line:
            assert 'OverloadInitConv' in lines[i + 1]
        if 'fc1.weight' in line:
            assert 'ConstantInit' in lines[i + 1]
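    # The log lists each parameter followed by the initializer that touched
    # it: conv1 by the NormalInit rule, conv2 by its own overloaded
    # `init_weights` (hence 'OverloadInitConv'), fc1 by ConstantInit.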

    # test corner case
    class OverloadInitConvFc(nn.Conv2d, BaseModule):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.conv1 = nn.Linear(1, 1)

        def init_weights(self):
            for p in self.parameters():
                with torch.no_grad():
                    p.fill_(1)

    class CheckLoggerModel(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    class TopLevelModule(BaseModule):

        def __init__(self, init_cfg=None, checklog_init_cfg=None):
            super().__init__(init_cfg)
            self.module1 = CheckLoggerModel(checklog_init_cfg)
            self.module2 = OverloadInitConvFc(1, 1, 1, 1)

    checklog_init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
        dict(type='Constant', layer='Linear', val=0., bias=1.)
    ]

    top_level_init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='module2', std=0.01, bias_prob=0.01))
    ]

    model = TopLevelModule(
        init_cfg=top_level_init_cfg, checklog_init_cfg=checklog_init_cfg)
    model.module1.init_weights()
    model.module2.init_weights()
    model.init_weights()
    model.module1.init_weights()
    model.module2.init_weights()
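    # Calling `init_weights` on submodules before and after the top-level
    # call is the corner case under test: once a module has been
    # initialized, repeated calls should be flagged in the log rather than
    # silently re-initializing.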

    assert not hasattr(model, '_params_init_info')
    model.init_weights()
    # assert `_params_init_info` would be deleted after `init_weights`
    assert not hasattr(model, '_params_init_info')
    # assert initialization information has been dumped
    assert os.path.exists(log_file)

    lines = mmcv.list_from_file(log_file)
    # check initialization information is right
    for i, line in enumerate(lines):
        if 'TopLevelModule' in line and 'init_cfg' not in line:
            # have been set init_flag
            assert 'the same' in line


def test_update_init_info():

    class DummyModel(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    model = DummyModel()

    from collections import defaultdict
    model._params_init_info = defaultdict(dict)
    for name, param in model.named_parameters():
        model._params_init_info[param]['init_info'] = 'init'
        model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
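    # `_params_init_info` maps each Parameter to its accumulated init
    # record: a human-readable 'init_info' string plus 'tmp_mean_value',
    # the mean that `update_init_info` uses to detect which parameters
    # have been modified since the last update.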

    with torch.no_grad():
        for p in model.parameters():
            p.fill_(1)

    update_init_info(model, init_info='fill_1')
    for item in model._params_init_info.values():
        assert item['init_info'] == 'fill_1'
        assert item['tmp_mean_value'] == 1

    # test assert for new parameters
    model.conv1.bias = nn.Parameter(torch.ones_like(model.conv1.bias))
    with pytest.raises(AssertionError):
        update_init_info(model, init_info=' ')
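    # Replacing `conv1.bias` creates a new Parameter object that was never
    # recorded in `_params_init_info`, so `update_init_info` fails its
    # consistency assertion.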


def test_model_weight_init():
    """
    Config
    model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                     Conv2d: weight=5, bias=6)
    ├──component1 (FooConv1d)
    ├──component2 (FooConv2d)
    ├──component3 (FooLinear)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear)
        ├──conv1d (FooConv1d)
    ├──reg (nn.Linear)

    Parameters after initialization
    model (FooModel)
    ├──component1 (FooConv1d, weight=3, bias=4)
    ├──component2 (FooConv2d, weight=5, bias=6)
    ├──component3 (FooLinear, weight=1, bias=2)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear, weight=1, bias=2)
        ├──conv1d (FooConv1d, weight=3, bias=4)
    ├──reg (nn.Linear, weight=1, bias=2)
    """
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=1, bias=2, layer='Linear'),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d')
        ],
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'),
        component4=dict(
            type='FooLinearConv1d',
            linear=dict(type='FooLinear'),
            conv1d=dict(type='FooConv1d')))
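    # No component defines its own init_cfg here, so every parameter is
    # initialized by the model-level per-layer-type rules above.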
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 3.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 4.0))
    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 5.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 6.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.linear.linear.weight,
        torch.full(model.component4.linear.linear.weight.shape, 1.0))
    assert torch.equal(
        model.component4.linear.linear.bias,
        torch.full(model.component4.linear.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.weight,
        torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.bias,
        torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
    assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
                                                    1.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))


def test_nest_components_weight_init():
    """
    Config
    model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                     Conv2d: weight=5, bias=6)
    ├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
    ├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
    ├──component3 (FooLinear)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear)
        ├──conv1d (FooConv1d)
    ├──reg (nn.Linear, weight=13, bias=14 via override)

    Parameters after initialization
    model (FooModel)
    ├──component1 (FooConv1d, weight=7, bias=8)
    ├──component2 (FooConv2d, weight=9, bias=10)
    ├──component3 (FooLinear, weight=1, bias=2)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear, weight=1, bias=2)
        ├──conv1d (FooConv1d, weight=3, bias=4)
    ├──reg (nn.Linear, weight=13, bias=14)
    """
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(
                type='Constant',
                val=1,
                bias=2,
                layer='Linear',
                override=dict(type='Constant', name='reg', val=13, bias=14)),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d'),
        ],
        component1=dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8)),
        component2=dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10)),
        component3=dict(type='FooLinear'),
        component4=dict(
            type='FooLinearConv1d',
            linear=dict(type='FooLinear'),
            conv1d=dict(type='FooConv1d')))
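    # component1 and component2 carry their own init_cfg, which takes
    # priority over the model-level rules; component4's children have none,
    # so they fall through to the model-level Linear/Conv1d rules; `reg` is
    # picked out by the `override` entry.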
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 7.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 8.0))
    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 9.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 10.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.linear.linear.weight,
        torch.full(model.component4.linear.linear.weight.shape, 1.0))
    assert torch.equal(
        model.component4.linear.linear.bias,
        torch.full(model.component4.linear.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.weight,
        torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.bias,
        torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 13.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))


def test_without_layer_weight_init():
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=1, bias=2, layer='Linear'),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d')
        ],
        component1=dict(
            type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    # component1's init_cfg has no `layer` key, so it matches nothing and
    # the model-level Conv1d rule (val=3, bias=4) applies instead
    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 3.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 4.0))
    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 5.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 6.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))
    assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
                                                    1.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))


def test_override_weight_init():
    # only the submodule named in 'override' is initialized
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=10, bias=20, override=dict(name='reg'))
        ],
        component1=dict(type='FooConv1d'),
        component3=dict(type='FooLinear'))
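    # `override` without an explicit `type` inherits the outer initializer
    # (Constant, val=10, bias=20) and restricts it to the named submodule,
    # so only `reg` is touched.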
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 10.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))

    # do not initialize others
    assert not torch.equal(
        model.component1.conv1d.weight,
        torch.full(model.component1.conv1d.weight.shape, 10.0))
    assert not torch.equal(
        model.component1.conv1d.bias,
        torch.full(model.component1.conv1d.bias.shape, 20.0))
    assert not torch.equal(
        model.component3.linear.weight,
        torch.full(model.component3.linear.weight.shape, 10.0))
    assert not torch.equal(
        model.component3.linear.bias,
        torch.full(model.component3.linear.bias.shape, 20.0))

    # 'override' has higher priority
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(
                type='Constant',
                val=1,
                bias=2,
                override=dict(name='reg', type='Constant', val=30, bias=40))
        ],
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))
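    # Here the override carries its own type/val/bias, which win over the
    # outer val=1/bias=2 for `reg`.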
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 30.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 40.0))


def test_sequential_model_weight_init():
    seq_model_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(*layers)
    seq_model.init_weights()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(
        *layers,
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
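    # The container-level init_cfg (val=4., bias=5.) does not overwrite the
    # children's own init_cfg, so the asserts below still see 0./1. and
    # 2./3.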
    seq_model.init_weights()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))


def test_modulelist_weight_init():
    models_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(layers)
    modellist.init_weights()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(
        layers,
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
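    # As with Sequential above, the children's own init_cfg wins over the
    # ModuleList-level one.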
    modellist.init_weights()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))


def test_moduledict_weight_init():
    models_cfg = dict(
        foo_conv_1d=dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        foo_conv_2d=dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    )
    layers = {
        name: build_from_cfg(cfg, COMPONENTS)
        for name, cfg in models_cfg.items()
    }
    modeldict = ModuleDict(layers)
    modeldict.init_weights()
    assert torch.equal(
        modeldict['foo_conv_1d'].conv1d.weight,
        torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))
    assert torch.equal(
        modeldict['foo_conv_1d'].conv1d.bias,
        torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))
    assert torch.equal(
        modeldict['foo_conv_2d'].conv2d.weight,
        torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))
    assert torch.equal(
        modeldict['foo_conv_2d'].conv2d.bias,
        torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority
    layers = {
        name: build_from_cfg(cfg, COMPONENTS)
        for name, cfg in models_cfg.items()
    }
    modeldict = ModuleDict(
        layers,
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
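    # Same priority rule for ModuleDict: per-module init_cfg beats the
    # container-level one.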
    modeldict.init_weights()
    assert torch.equal(
        modeldict['foo_conv_1d'].conv1d.weight,
        torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))
    assert torch.equal(
        modeldict['foo_conv_1d'].conv1d.bias,
        torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))
    assert torch.equal(
        modeldict['foo_conv_2d'].conv2d.weight,
        torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))
    assert torch.equal(
        modeldict['foo_conv_2d'].conv2d.bias,
        torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))