# Copyright (c) OpenMMLab. All rights reserved. import numpy as np import pytest import mmcv try: import torch except ImportError: torch = None else: import torch.nn as nn def test_assert_dict_contains_subset(): dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)} # case 1 expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)} assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 2 expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)} assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 3 expected_subset = {'a': 'test1', 'b': 2, 'c': None} assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 4 expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)} assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 5 dict_obj = { 'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[5, 3, 5], [1, 2, 3]]) } expected_subset = { 'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[5, 3, 5], [6, 2, 3]]) } assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 6 dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) if torch is not None: dict_obj = { 'a': 'test1', 'b': 2, 'c': (4, 6), 'd': torch.tensor([5, 3, 5]) } # case 7 expected_subset = {'d': torch.tensor([5, 5, 5])} assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) # case 8 expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])} assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) def test_assert_attrs_equal(): class TestExample: a, b, c = 1, ('wvi', 3), [4.5, 3.14] def test_func(self): return self.b # case 1 assert mmcv.assert_attrs_equal(TestExample, { 'a': 1, 'b': ('wvi', 3), 'c': [4.5, 3.14] }) # case 2 assert not mmcv.assert_attrs_equal(TestExample, { 'a': 1, 'b': ('wvi', 3), 'c': [4.5, 3.14, 2] }) # case 3 assert not mmcv.assert_attrs_equal(TestExample, { 'bc': 54, 'c': [4.5, 3.14] }) # case 4 assert mmcv.assert_attrs_equal(TestExample, { 'b': ('wvi', 3), 'test_func': TestExample.test_func }) if torch is not None: class TestExample: a, b = torch.tensor([1]), torch.tensor([4, 5]) # case 5 assert mmcv.assert_attrs_equal(TestExample, { 'a': torch.tensor([1]), 'b': torch.tensor([4, 5]) }) # case 6 assert not mmcv.assert_attrs_equal(TestExample, { 'a': torch.tensor([1]), 'b': torch.tensor([4, 6]) }) assert_dict_has_keys_data_1 = [({ 'res_layer': 1, 'norm_layer': 2, 'dense_layer': 3 })] assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True), (['res_layer', 'conv_layer'], False)] @pytest.mark.parametrize('obj', assert_dict_has_keys_data_1) @pytest.mark.parametrize('expected_keys, ret_value', assert_dict_has_keys_data_2) def test_assert_dict_has_keys(obj, expected_keys, ret_value): assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])] assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True), (['res_layer', 'dense_layer', 'norm_layer'], True), (['res_layer', 'norm_layer'], False), (['res_layer', 'conv_layer', 'norm_layer'], False)] @pytest.mark.parametrize('result_keys', assert_keys_equal_data_1) @pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2) def test_assert_keys_equal(result_keys, target_keys, ret_value): assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value @pytest.mark.skipif(torch is None, reason='requires torch library') def test_assert_is_norm_layer(): # case 1 assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3)) # case 2 assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128)) # case 3 assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64)) # case 4 assert not mmcv.assert_is_norm_layer(nn.Sigmoid()) @pytest.mark.skipif(torch is None, reason='requires torch library') def test_assert_params_all_zeros(): demo_module = nn.Conv2d(3, 64, 3) nn.init.constant_(demo_module.weight, 0) nn.init.constant_(demo_module.bias, 0) assert mmcv.assert_params_all_zeros(demo_module) nn.init.xavier_normal_(demo_module.weight) nn.init.constant_(demo_module.bias, 0) assert not mmcv.assert_params_all_zeros(demo_module) demo_module = nn.Linear(2048, 400, bias=False) nn.init.constant_(demo_module.weight, 0) assert mmcv.assert_params_all_zeros(demo_module) nn.init.normal_(demo_module.weight, mean=0, std=0.01) assert not mmcv.assert_params_all_zeros(demo_module) def test_check_python_script(capsys): mmcv.utils.check_python_script('./tests/data/scripts/hello.py zz') captured = capsys.readouterr().out assert captured == 'hello zz!\n' mmcv.utils.check_python_script('./tests/data/scripts/hello.py agent') captured = capsys.readouterr().out assert captured == 'hello agent!\n' # Make sure that wrong cmd raises an error with pytest.raises(SystemExit): mmcv.utils.check_python_script('./tests/data/scripts/hello.py li zz')