Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np | |
import pytest | |
import mmcv | |
try: | |
import torch | |
except ImportError: | |
torch = None | |
else: | |
import torch.nn as nn | |
def test_assert_dict_contains_subset():
    """Check ``mmcv.assert_dict_contains_subset`` on plain, numpy and torch values.

    Cases 1-4 compare against a dict of builtin values, cases 5-6 add a
    numpy array value, and cases 7-8 (only when torch is importable) use a
    tensor value.
    """
    contains_subset = mmcv.assert_dict_contains_subset
    actual = {'a': 'test1', 'b': 2, 'c': (4, 6)}

    # case 1: the expected subset equals the dict itself
    assert contains_subset(actual, {'a': 'test1', 'b': 2, 'c': (4, 6)})

    # case 2: the tuple under 'c' has its elements swapped
    assert not contains_subset(actual, {'a': 'test1', 'b': 2, 'c': (6, 4)})

    # case 3: the value under 'c' is None instead of the tuple
    assert not contains_subset(actual, {'a': 'test1', 'b': 2, 'c': None})

    # case 4: the expected subset has a key 'd' that the dict lacks
    assert not contains_subset(actual, {'a': 'test1', 'b': 2, 'd': (4, 6)})

    # case 5: numpy arrays under 'd' differ in a single element
    actual = {
        'a': 'test1',
        'b': 2,
        'c': (4, 6),
        'd': np.array([[5, 3, 5], [1, 2, 3]])
    }
    wanted = {
        'a': 'test1',
        'b': 2,
        'c': (4, 6),
        'd': np.array([[5, 3, 5], [6, 2, 3]])
    }
    assert not contains_subset(actual, wanted)

    # case 6: numpy arrays under 'd' are equal
    actual = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
    wanted = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
    assert contains_subset(actual, wanted)

    # The tensor cases only run when torch could be imported.
    if torch is None:
        return

    actual = {
        'a': 'test1',
        'b': 2,
        'c': (4, 6),
        'd': torch.tensor([5, 3, 5])
    }

    # case 7: tensors have the same shape but different values
    wanted = {'d': torch.tensor([5, 5, 5])}
    assert not contains_subset(actual, wanted)

    # case 8: tensors have different shapes
    wanted = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
    assert not contains_subset(actual, wanted)
def test_assert_attrs_equal():
    """Check ``mmcv.assert_attrs_equal`` against class attributes.

    Cases 1-4 use builtin-valued attributes and a method; cases 5-6 (only
    when torch is importable) use tensor-valued attributes.
    """
    attrs_equal = mmcv.assert_attrs_equal

    class TestExample:
        a = 1
        b = ('wvi', 3)
        c = [4.5, 3.14]

        def test_func(self):
            return self.b

    # case 1: every listed attribute matches
    expected = {'a': 1, 'b': ('wvi', 3), 'c': [4.5, 3.14]}
    assert attrs_equal(TestExample, expected)

    # case 2: the list under 'c' has an extra element
    expected = {'a': 1, 'b': ('wvi', 3), 'c': [4.5, 3.14, 2]}
    assert not attrs_equal(TestExample, expected)

    # case 3: 'bc' is not an attribute of the class
    expected = {'bc': 54, 'c': [4.5, 3.14]}
    assert not attrs_equal(TestExample, expected)

    # case 4: a method can be compared as an attribute as well
    expected = {'b': ('wvi', 3), 'test_func': TestExample.test_func}
    assert attrs_equal(TestExample, expected)

    # The tensor cases only run when torch could be imported.
    if torch is None:
        return

    class TestExample:
        a = torch.tensor([1])
        b = torch.tensor([4, 5])

    # case 5: tensor attributes with equal values
    expected = {'a': torch.tensor([1]), 'b': torch.tensor([4, 5])}
    assert attrs_equal(TestExample, expected)

    # case 6: tensor 'b' differs in one element
    expected = {'a': torch.tensor([1]), 'b': torch.tensor([4, 6])}
    assert not attrs_equal(TestExample, expected)
# Parametrization data for ``test_assert_dict_has_keys``.
# data_1 holds the dict under test; data_2 pairs a list of keys with the
# return value expected from ``mmcv.assert_dict_has_keys``.
assert_dict_has_keys_data_1 = [({
    'res_layer': 1,
    'norm_layer': 2,
    'dense_layer': 3
})]
assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
                               (['res_layer', 'conv_layer'], False)]


# Without these decorators pytest would look for fixtures named ``obj``,
# ``expected_keys`` and ``ret_value`` and error out, leaving the data lists
# above unused.
@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
@pytest.mark.parametrize('expected_keys, ret_value',
                         assert_dict_has_keys_data_2)
def test_assert_dict_has_keys(obj, expected_keys, ret_value):
    """Check the result of ``mmcv.assert_dict_has_keys`` for each key list.

    Args:
        obj: Dict passed as the first argument.
        expected_keys: Keys passed as the second argument.
        ret_value: Value the call is expected to return.
    """
    assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value
# Parametrization data for ``test_assert_keys_equal``.
# data_1 holds the result keys; data_2 pairs target keys with the return
# value expected from ``mmcv.assert_keys_equal``.
assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
                            (['res_layer', 'dense_layer', 'norm_layer'], True),
                            (['res_layer', 'norm_layer'], False),
                            (['res_layer', 'conv_layer', 'norm_layer'], False)]


# Without these decorators pytest would look for fixtures named
# ``result_keys``, ``target_keys`` and ``ret_value`` and error out, leaving
# the data lists above unused.
@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
def test_assert_keys_equal(result_keys, target_keys, ret_value):
    """Check the result of ``mmcv.assert_keys_equal`` for each target list.

    Args:
        result_keys: Keys passed as the first argument.
        target_keys: Keys passed as the second argument.
        ret_value: Value the call is expected to return.
    """
    assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value
# ``nn`` is only bound when ``import torch`` succeeded at the top of the
# file, so skip instead of failing with a NameError when torch is missing.
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_is_norm_layer():
    """Check ``mmcv.assert_is_norm_layer`` on norm and non-norm modules."""
    # case 1: a convolution is not reported as a norm layer
    assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3))

    # case 2: batch norm is reported as a norm layer
    assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128))

    # case 3: group norm is reported as a norm layer
    assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64))

    # case 4: an activation is not reported as a norm layer
    assert not mmcv.assert_is_norm_layer(nn.Sigmoid())
# ``nn`` is only bound when ``import torch`` succeeded at the top of the
# file, so skip instead of failing with a NameError when torch is missing.
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_params_all_zeros():
    """Check ``mmcv.assert_params_all_zeros`` with and without a bias term."""
    # Conv layer with both weight and bias set to zero.
    demo_module = nn.Conv2d(3, 64, 3)
    nn.init.constant_(demo_module.weight, 0)
    nn.init.constant_(demo_module.bias, 0)
    assert mmcv.assert_params_all_zeros(demo_module)

    # Re-initialised weight with a zero bias must no longer pass.
    nn.init.xavier_normal_(demo_module.weight)
    nn.init.constant_(demo_module.bias, 0)
    assert not mmcv.assert_params_all_zeros(demo_module)

    # Linear layer created without a bias parameter.
    demo_module = nn.Linear(2048, 400, bias=False)
    nn.init.constant_(demo_module.weight, 0)
    assert mmcv.assert_params_all_zeros(demo_module)

    # Normally-initialised weight must no longer pass.
    nn.init.normal_(demo_module.weight, mean=0, std=0.01)
    assert not mmcv.assert_params_all_zeros(demo_module)
def test_check_python_script(capsys):
    """Run the sample ``hello.py`` script via ``mmcv.utils.check_python_script``.

    Args:
        capsys: pytest fixture used to read what the script wrote to stdout.
    """
    run_script = mmcv.utils.check_python_script

    # Each command is paired with the exact stdout the test expects from it.
    accepted = (
        ('./tests/data/scripts/hello.py zz', 'hello zz!\n'),
        ('./tests/data/scripts/hello.py agent', 'hello agent!\n'),
    )
    for command, expected_stdout in accepted:
        run_script(command)
        printed = capsys.readouterr().out
        assert printed == expected_stdout

    # Make sure that wrong cmd raises an error
    with pytest.raises(SystemExit):
        run_script('./tests/data/scripts/hello.py li zz')