# Copyright (c) OpenMMLab. All rights reserved.
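"""Tests for the weight-initialization utilities in ``mmcv.cnn``."""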
import os.path as osp
import random
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import torch
from scipy import stats
from torch import nn

from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
initialize, kaiming_init, normal_init, trunc_normal_init,
uniform_init, xavier_init)

if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)


def test_constant_init():
conv_module = nn.Conv2d(3, 16, 3)
constant_init(conv_module, 0.1)
assert conv_module.weight.allclose(
torch.full_like(conv_module.weight, 0.1))
assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
constant_init(conv_module_no_bias, 0.1)
assert conv_module.weight.allclose(
torch.full_like(conv_module.weight, 0.1))


def test_xavier_init():
conv_module = nn.Conv2d(3, 16, 3)
xavier_init(conv_module, bias=0.1)
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
xavier_init(conv_module, distribution='uniform')
# TODO: sanity check of weight distribution, e.g. mean, std
with pytest.raises(AssertionError):
xavier_init(conv_module, distribution='student-t')
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
xavier_init(conv_module_no_bias)


def test_normal_init():
conv_module = nn.Conv2d(3, 16, 3)
normal_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
normal_init(conv_module_no_bias)
# TODO: sanity check distribution, e.g. mean, std


def test_trunc_normal_init():
def _random_float(a, b):
return (b - a) * random.random() + a

    def _is_trunc_normal(tensor, mean, std, a, b):
# scipy's trunc norm is suited for data drawn from N(0, 1),
# so we need to transform our data to test it using scipy.
z_samples = (tensor.view(-1) - mean) / std
z_samples = z_samples.tolist()
a0 = (a - mean) / std
b0 = (b - mean) / std
p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
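        # a deliberately loose significance level keeps this stochastic
        # check stable across random seeds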
return p_value > 0.0001

    conv_module = nn.Conv2d(3, 16, 3)
mean = _random_float(-3, 3)
std = _random_float(.01, 1)
a = _random_float(mean - 2 * std, mean)
b = _random_float(mean, mean + 2 * std)
trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
trunc_normal_init(conv_module_no_bias)
# TODO: sanity check distribution, e.g. mean, std


def test_uniform_init():
conv_module = nn.Conv2d(3, 16, 3)
uniform_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
uniform_init(conv_module_no_bias)


def test_kaiming_init():
conv_module = nn.Conv2d(3, 16, 3)
kaiming_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
kaiming_init(conv_module, distribution='uniform')
with pytest.raises(AssertionError):
kaiming_init(conv_module, distribution='student-t')
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
kaiming_init(conv_module_no_bias)


def test_caffe2_xavier_init():
conv_module = nn.Conv2d(3, 16, 3)
caffe2_xavier_init(conv_module)


def test_bias_init_with_prob():
conv_module = nn.Conv2d(3, 16, 3)
prior_prob = 0.1
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
# TODO: sanity check of weight distribution, e.g. mean, std
bias = float(-np.log((1 - prior_prob) / prior_prob))
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
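    # hedged sanity check: bias_init_with_prob is the inverse sigmoid of the
    # prior probability, so sigmoid(bias) should recover prior_prob
    assert torch.sigmoid(torch.tensor(bias)).allclose(
        torch.tensor(prior_prob))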


def test_constantinit():
    """test ConstantInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = ConstantInit(val=1, bias=2, layer='Conv2d')
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 1.))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
func(model)
res = bias_init_with_prob(0.01)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
# test bias_prob type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias_prob='1')
# test layer input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, layer=1)


def test_xavierinit():
"""test XavierInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = XavierInit(bias=0.1, layer='Conv2d')
func(model)
    assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
    assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
res = bias_init_with_prob(0.01)
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
func(model)
assert not torch.all(model[0].weight == 4.)
assert not torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
# test bias input type
with pytest.raises(TypeError):
func = XavierInit(bias='0.1', layer='Conv2d')
    # test layer input type
with pytest.raises(TypeError):
func = XavierInit(bias=0.1, layer=1)


def test_normalinit():
    """test NormalInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = NormalInit(mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = NormalInit(
mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)


def test_truncnormalinit():
"""test TruncNormalInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = TruncNormalInit(
mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = TruncNormalInit(
mean=300,
std=1e-5,
a=100,
b=400,
bias_prob=0.01,
layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = TruncNormalInit(
mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)


def test_uniforminit():
    """test UniformInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
100.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
100.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
res = bias_init_with_prob(0.01)
func(model)
assert torch.all(model[0].weight == 100.)
assert torch.all(model[2].weight == 100.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)


def test_kaiminginit():
"""test KaimingInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = KaimingInit(bias=0.1, layer='Conv2d')
func(model)
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = KaimingInit(bias=0.1, layer='_ConvNd')
func(model)
assert torch.all(model[0].bias == 0.1)
assert torch.all(model[2].bias == 0.1)
func = KaimingInit(a=100, bias=10, layer='_ConvNd')
constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))


def test_caffe2xavierinit():
"""test Caffe2XavierInit."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
func(model)
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
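

# A small helper module with a named Linear and two Conv2d children, used
# below to exercise name-based overrides and prefix-based pretrained loading.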
class FooModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 2)
self.conv2d = nn.Conv2d(3, 1, 3)
self.conv2d_2 = nn.Conv2d(3, 2, 3)


def test_pretrainedinit():
"""test PretrainedInit class."""
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
modelA.apply(constant_func)
modelB = FooModule()
    modelC = nn.Linear(1, 2)
    with TemporaryDirectory() as tmpdir:
        # save the checkpoint inside the temporary directory so the test
        # does not leave a stray modelA.pth in the working directory
        checkpoint = osp.join(tmpdir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        funcB = PretrainedInit(checkpoint=checkpoint)
        funcC = PretrainedInit(checkpoint=checkpoint, prefix='linear.')
        funcB(modelB)
assert torch.equal(modelB.linear.weight,
torch.full(modelB.linear.weight.shape, 1.))
assert torch.equal(modelB.linear.bias,
torch.full(modelB.linear.bias.shape, 2.))
assert torch.equal(modelB.conv2d.weight,
torch.full(modelB.conv2d.weight.shape, 1.))
assert torch.equal(modelB.conv2d.bias,
torch.full(modelB.conv2d.bias.shape, 2.))
assert torch.equal(modelB.conv2d_2.weight,
torch.full(modelB.conv2d_2.weight.shape, 1.))
assert torch.equal(modelB.conv2d_2.bias,
torch.full(modelB.conv2d_2.bias.shape, 2.))
funcC(modelC)
assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
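

# ``initialize`` is the config-driven entry point: it accepts a single
# init_cfg dict or a list of them, and the optional ``override`` key applies
# a different initializer to the submodule with the given name.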
def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()
# test layer key
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
assert init_cfg == dict(
type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
# test init_cfg with list type
init_cfg = [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
assert init_cfg == [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
# test layer key and override key
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test override key
init_cfg = dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
initialize(foonet, init_cfg)
assert not torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 5.))
assert not torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 6.))
assert not torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 5.))
assert not torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 6.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 5.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 6.))
assert init_cfg == dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    with TemporaryDirectory() as tmpdir:
        # again, keep the checkpoint inside the temporary directory
        checkpoint = osp.join(tmpdir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        init_cfg = dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
        assert init_cfg == dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test init_cfg type
with pytest.raises(TypeError):
init_cfg = 'init_cfg'
initialize(foonet, init_cfg)
# test override value type
with pytest.raises(TypeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override='conv')
initialize(foonet, init_cfg)
# test override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
initialize(foonet, init_cfg)
# test list override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=[
dict(type='Constant', name='conv2d', val=3, bias=4),
dict(type='Constant', name='conv2d_3', val=5, bias=6)
])
initialize(foonet, init_cfg)
    # test override that passes init args without a type key
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
# test override without name
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(type='Constant', val=3, bias=4))
initialize(foonet, init_cfg)
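

# A minimal usage sketch (comments only, not executed here): downstream code
# typically drives these initializers through ``initialize`` with a
# registered type name, e.g.
#
#   model = nn.Conv2d(3, 16, 3)
#   initialize(model, dict(type='Kaiming', layer='Conv2d'))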