Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import sys | |
from collections import OrderedDict | |
from tempfile import TemporaryDirectory | |
from unittest.mock import MagicMock, patch | |
import pytest | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.nn.parallel import DataParallel | |
from mmcv.fileio.file_client import PetrelBackend | |
from mmcv.parallel.registry import MODULE_WRAPPERS | |
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix, | |
get_state_dict, load_checkpoint, | |
load_from_local, load_from_pavi, | |
save_checkpoint) | |
# Register placeholder modules for the optional ``petrel_client`` SDK so that
# code importing it (presumably ``PetrelBackend``, used by
# ``test_save_checkpoint``) works without the real package installed. Each
# name gets its own independent mock.
sys.modules.update({
    'petrel_client': MagicMock(),
    'petrel_client.client': MagicMock(),
})
@MODULE_WRAPPERS.register_module()
class DDPWrapper:
    """Minimal stand-in for a distributed module wrapper.

    It only stores the wrapped model as ``self.module``. The tests expect
    ``get_state_dict`` / ``save_checkpoint`` to look through it to
    ``.module``: state-dict keys carry no wrapper prefix, and ``CLASSES`` is
    read from the inner module. Registering the class in ``MODULE_WRAPPERS``
    is presumably what makes mmcv's wrapper check recognise it. Without
    registration it would be treated as a plain object, which has no
    ``nn.Module`` API.
    """

    def __init__(self, module):
        self.module = module
class Block(nn.Module):
    """Toy sub-module holding a 1x1 ``Conv2d`` and a ``BatchNorm2d``.

    Both layers work on 3 channels. No ``forward`` is defined; the class only
    exists to give ``Model`` a nested state dict (``conv.*`` and ``norm.*``).
    """

    def __init__(self):
        super().__init__()
        channels = 3
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.norm = nn.BatchNorm2d(num_features=channels)
class Model(nn.Module):
    """Toy two-level model: a nested ``Block`` followed by a 1x1 conv.

    The resulting state-dict keys are ``block.conv.*``, ``block.norm.*`` and
    ``conv.*``, which the tests in this file check against.
    """

    def __init__(self):
        super().__init__()
        # registration order (block, then conv) fixes the state-dict order
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)
class Mockpavimodel:
    """Fake pavi model-cloud entry used by ``test_load_pavimodel_dist``.

    It exposes a ``name`` attribute and a ``download`` method that writes
    nothing.
    """

    def __init__(self, name='fakename'):
        self.name = name

    def download(self, file):
        """Pretend to download to *file*; intentionally a no-op."""
        return None
def assert_tensor_equal(tensor_a, tensor_b):
    """Assert that two tensors have the same shape and identical elements.

    ``tensor_a.eq(tensor_b)`` broadcasts. On its own it would accept, e.g., a
    2-element tensor compared against a scalar holding the same value, so the
    shapes are checked explicitly first.
    """
    assert tensor_a.shape == tensor_b.shape, (
        f'shape mismatch: {tuple(tensor_a.shape)} vs '
        f'{tuple(tensor_b.shape)}')
    assert tensor_a.eq(tensor_b).all()
def test_get_state_dict():
    """``get_state_dict`` must produce the same flat keys and tensors for a
    bare model, a ``DDPWrapper`` around it, and a wrapper whose direct
    children are additionally wrapped in ``DataParallel``."""
    is_parrots = torch.__version__ == 'parrots'

    expected_keys = {
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean',
        'block.norm.running_var', 'conv.weight', 'conv.bias'
    }
    if not is_parrots:
        # the parrots key set omits the BatchNorm step counter
        expected_keys.add('block.norm.num_batches_tracked')

    norm_attrs = ['weight', 'bias', 'running_mean', 'running_var']
    if not is_parrots:
        norm_attrs.append('num_batches_tracked')

    def check(target, block, conv):
        # ``block``/``conv`` are the unwrapped layers the flat keys refer to
        state_dict = get_state_dict(target)
        assert isinstance(state_dict, OrderedDict)
        assert set(state_dict.keys()) == expected_keys
        for attr in ('weight', 'bias'):
            assert_tensor_equal(state_dict[f'block.conv.{attr}'],
                                getattr(block.conv, attr))
        for attr in norm_attrs:
            assert_tensor_equal(state_dict[f'block.norm.{attr}'],
                                getattr(block.norm, attr))
        for attr in ('weight', 'bias'):
            assert_tensor_equal(state_dict[f'conv.{attr}'],
                                getattr(conv, attr))

    # bare model
    model = Model()
    check(model, model.block, model.conv)

    # outer wrapper
    wrapped_model = DDPWrapper(model)
    inner = wrapped_model.module
    check(wrapped_model, inner.block, inner.conv)

    # wrapped inner module: every direct child becomes a DataParallel
    for child_name, child in list(inner._modules.items()):
        inner._modules[child_name] = DataParallel(child)
    check(wrapped_model, inner.block.module, inner.conv.module)
def test_load_pavimodel_dist():
    """``load_from_pavi`` rejects names without the ``pavi://`` prefix and
    raises ``FileNotFoundError`` when the (mocked) download produced no file.

    The fake ``pavi`` modules are installed with ``patch.dict`` so they are
    removed from ``sys.modules`` again when the test ends, instead of leaking
    into every test that runs afterwards.
    """
    fake_modules = {'pavi': MagicMock(), 'pavi.modelcloud': MagicMock()}
    with patch.dict(sys.modules, fake_modules):
        pavimodel = Mockpavimodel()
        import pavi
        pavi.modelcloud.get = MagicMock(return_value=pavimodel)
        with pytest.raises(AssertionError):
            # test pavi prefix
            _ = load_from_pavi('MyPaviFolder/checkpoint.pth')
        with pytest.raises(FileNotFoundError):
            # there is not such checkpoint for us to load
            _ = load_from_pavi('pavi://checkpoint.pth')
def test_load_checkpoint_with_prefix():
    """``_load_checkpoint_with_prefix`` returns the ``conv2d`` sub-state-dict
    with the prefix stripped, and raises ``AssertionError`` for a prefix that
    matches nothing.

    The sibling ``conv2d_2`` shares the textual prefix and has a different
    shape, so an over-eager prefix match would presumably break the
    equality checks.
    """
    import os

    class FooModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.conv2d = nn.Conv2d(3, 1, 3)
            self.conv2d_2 = nn.Conv2d(3, 2, 3)

    model = FooModule()
    nn.init.constant_(model.linear.weight, 1)
    nn.init.constant_(model.linear.bias, 2)
    nn.init.constant_(model.conv2d.weight, 3)
    nn.init.constant_(model.conv2d.bias, 4)
    nn.init.constant_(model.conv2d_2.weight, 5)
    nn.init.constant_(model.conv2d_2.bias, 6)

    with TemporaryDirectory() as tmp_dir:
        # write inside the temporary directory (not the CWD) so the file is
        # removed together with the directory
        checkpoint_path = os.path.join(tmp_dir, 'model.pth')
        torch.save(model.state_dict(), checkpoint_path)
        prefix = 'conv2d'
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint_path)
        assert torch.equal(model.conv2d.state_dict()['weight'],
                           state_dict['weight'])
        assert torch.equal(model.conv2d.state_dict()['bias'],
                           state_dict['bias'])

        # test whether prefix is in pretrained model
        with pytest.raises(AssertionError):
            prefix = 'back'
            _load_checkpoint_with_prefix(prefix, checkpoint_path)
def test_load_checkpoint():
    """``load_checkpoint`` applies the ``revise_keys`` regex substitutions to
    checkpoint keys before loading: first to add a ``backbone.`` prefix, then
    to strip it."""
    import os
    import re
    import tempfile

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    # A private directory avoids clashing with other tests that also use
    # ``checkpoint.pth``, and it is cleaned up even if an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')

        # add prefix
        torch.save(model.state_dict(), checkpoint_path)
        state_dict = load_checkpoint(
            pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
        backbone_state = pmodel.backbone.state_dict()
        for key in backbone_state:
            assert torch.equal(backbone_state[key], state_dict[key])

        # strip prefix
        torch.save(pmodel.state_dict(), checkpoint_path)
        state_dict = load_checkpoint(
            model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])
        model_state = model.state_dict()
        for key in state_dict:
            key_stripped = re.sub(r'^backbone\.', '', key)
            assert torch.equal(model_state[key_stripped], state_dict[key])
def test_load_checkpoint_metadata():
    """A module can read the per-module ``version`` from checkpoint metadata
    in ``_load_from_state_dict`` to migrate renamed parameters."""
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    class ModelV1(nn.Module):

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv1 = nn.Conv2d(3, 3, 1)
            self.conv2 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv1.weight)
            nn.init.normal_(self.conv2.weight)

    class ModelV2(nn.Module):
        # V1's conv1/conv2 were renamed to conv0/conv1 in this version
        _version = 2

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv0 = nn.Conv2d(3, 3, 1)
            self.conv1 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv0.weight)
            nn.init.normal_(self.conv1.weight)

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  *args, **kwargs):
            """Rename pre-V2 parameter keys to their V2 names, then load."""
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
                renamed = {}
                for k in list(state_dict.keys()):
                    for ori_str, new_str in convert_map.items():
                        old_prefix = prefix + ori_str + '.'
                        if k.startswith(old_prefix):
                            # Pop everything first and re-insert afterwards:
                            # 'conv1' is both a source and a target name, so
                            # renaming in place would depend on key order.
                            new_key = prefix + new_str + '.' + k[len(
                                old_prefix):]
                            renamed[new_key] = state_dict.pop(k)
                            break
                state_dict.update(renamed)
            super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                          *args, **kwargs)

    # clone(): detach() alone shares storage with the parameter, and loading
    # copies into the parameters in place, so a detached view would follow
    # the later loads and make the comparisons vacuous.
    model_v1 = ModelV1()
    model_v1_conv0_weight = model_v1.conv1.weight.detach().clone()
    model_v1_conv1_weight = model_v1.conv2.weight.detach().clone()
    model_v2 = ModelV2()
    model_v2_conv0_weight = model_v2.conv0.weight.detach().clone()
    model_v2_conv1_weight = model_v2.conv1.weight.detach().clone()

    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_v1_path = os.path.join(tmp_dir, 'checkpoint_v1.pth')
        ckpt_v2_path = os.path.join(tmp_dir, 'checkpoint_v2.pth')
        # Save checkpoint
        save_checkpoint(model_v1, ckpt_v1_path)
        save_checkpoint(model_v2, ckpt_v2_path)

        # test load v1 model
        load_checkpoint(model_v2, ckpt_v1_path)
        assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
        assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)

        # test load v2 model: must restore the weights saved above
        load_checkpoint(model_v2, ckpt_v2_path)
        assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
        assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
def test_load_classes_name():
    """``save_checkpoint`` stores ``CLASSES`` in ``meta`` only when the model,
    or the module inside ``DDPWrapper``, defines it."""
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    # A private directory avoids collisions with other tests that write
    # ``checkpoint.pth``, and it is removed even if an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')

        model = Model()
        save_checkpoint(model, checkpoint_path)
        checkpoint = load_checkpoint(model, checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

        model.CLASSES = ('class1', 'class2')
        save_checkpoint(model, checkpoint_path)
        checkpoint = load_checkpoint(model, checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
        assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')

        model = Model()
        wrapped_model = DDPWrapper(model)
        save_checkpoint(wrapped_model, checkpoint_path)
        checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

        wrapped_model.module.CLASSES = ('class1', 'class2')
        save_checkpoint(wrapped_model, checkpoint_path)
        checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
        assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')
def test_checkpoint_loader():
    """``CheckpointLoader`` picks a loader by filename prefix, allows new
    schemes to be registered (``force`` to override), and prefers the longest
    matching prefix."""
    import os
    import tempfile

    from mmcv.runner import CheckpointLoader, _load_checkpoint, save_checkpoint

    # round-trip through a private directory so the file is always removed
    # and cannot collide with other tests writing ``checkpoint.pth``
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')
        model = Model()
        save_checkpoint(model, checkpoint_path)
        checkpoint = _load_checkpoint(checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth',
        'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth',
        'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local', 'load_from_ceph',
        'load_from_ceph', 'load_from_local', 'load_from_local'
    ]
    # zip() silently truncates, so make sure every filename has an expectation
    assert len(filenames) == len(fn_names)
    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader: a scheme only resolves once it is registered
    CheckpointLoader.register_scheme('ftp://', load_from_ftp)
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort: with both 'a/b' and 'a/b/c' registered, the longer prefix
    # must win for 'a/b/c/d'
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    CheckpointLoader.register_scheme('a/b', load_from_ab)
    CheckpointLoader.register_scheme('a/b/c', load_from_abc)
    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
def test_save_checkpoint(tmp_path):
    """Exercise ``save_checkpoint`` argument validation, local-disk saving
    with each optional argument, and S3-style paths with
    ``PetrelBackend.put`` patched out."""
    model = Model()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # ``meta`` must be a dict
    with pytest.raises(TypeError):
        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')

    # 1. save to disk: (file name, extra positional args, extra kwargs)
    disk_cases = [
        ('checkpoint1.pth', (), {}),
        ('checkpoint2.pth', (optimizer, ), {}),
        ('checkpoint3.pth', (), {'meta': {'test': 'test'}}),
        ('checkpoint4.pth', (), {'file_client_args': {'backend': 'disk'}}),
    ]
    for basename, extra_args, extra_kwargs in disk_cases:
        target = str(tmp_path / basename)
        save_checkpoint(model, target, *extra_args, **extra_kwargs)

    # 2. save to petrel oss: the upload call is mocked, only its use is checked
    petrel_cases = [
        ('s3://path/of/your/checkpoint1.pth', {}),
        ('s3://path//of/your/checkpoint2.pth', {
            'file_client_args': {'backend': 'petrel'}
        }),
    ]
    for target, extra_kwargs in petrel_cases:
        with patch.object(PetrelBackend, 'put') as mocked_put:
            save_checkpoint(model, target, **extra_kwargs)
            mocked_put.assert_called()
def test_load_from_local():
    """``load_from_local`` accepts a ``~``-relative path: a checkpoint saved
    under the expanded home directory is read back via ``~/<name>``."""
    import os
    file_name = 'dummy_checkpoint_used_to_test_load_from_local.pth'
    home_path = os.path.expanduser('~')
    checkpoint_path = os.path.join(home_path, file_name)
    model = Model()
    save_checkpoint(model, checkpoint_path)
    try:
        checkpoint = load_from_local('~/' + file_name, map_location=None)
        assert_tensor_equal(checkpoint['state_dict']['block.conv.weight'],
                            model.block.conv.weight)
    finally:
        # the file lives in the user's real home directory: always remove it,
        # even when an assertion above fails
        os.remove(checkpoint_path)