Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import logging | |
import numpy as np | |
import pytest | |
import torch | |
from mmcv.parallel.data_container import DataContainer | |
from mmcv.utils import IS_IPU_AVAILABLE | |
if IS_IPU_AVAILABLE: | |
from mmcv.device.ipu.hierarchical_data_manager import \ | |
HierarchicalDataManager | |
# pytest marker for IPU-dependent tests: any test decorated with it is
# skipped whenever mmcv's IS_IPU_AVAILABLE flag is falsy.
skip_no_ipu = pytest.mark.skipif(
    condition=not IS_IPU_AVAILABLE,
    reason='test case under ipu environment',
)
@skip_no_ipu
def test_HierarchicalData():
    """Exercise HierarchicalDataManager on nested data and on a bare tensor.

    Covers recording, collecting, updating and cleaning tensors held inside
    dicts, lists and a DataContainer. Also covers the error paths for
    unsupported leaf types and for strict updates, and the single-tensor case.

    HierarchicalDataManager is only imported when an IPU is available, so the
    test must carry ``skip_no_ipu``. Without it, it fails with a NameError
    instead of being skipped.
    """
    # test hierarchical data
    hierarchical_data_sample = {
        'a': torch.rand(3, 4),
        'b': np.random.rand(3, 4),
        'c': DataContainer({
            'a': torch.rand(3, 4),
            'b': 4,
            'c': 'd'
        }),
        'd': 123,
        'e': [1, 3, torch.rand(3, 4),
              np.random.rand(3, 4)],
        'f': {
            'a': torch.rand(3, 4),
            'b': np.random.rand(3, 4),
            'c': [1, 'asd']
        }
    }
    # Every torch.Tensor reachable from the sample. The assertions below check
    # that the manager hands back these very objects (by identity), not copies.
    all_tensors = [
        hierarchical_data_sample['a'],
        hierarchical_data_sample['c'].data['a'],
        hierarchical_data_sample['e'][2],
        hierarchical_data_sample['f']['a'],
    ]
    all_tensors_id = {id(ele) for ele in all_tensors}

    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(hierarchical_data_sample)
    tensors = hd.collect_all_tensors()
    for t in tensors:
        assert id(t) in all_tensors_id
    tensors[0].add_(1)
    hd.update_all_tensors(tensors)
    data = hd.hierarchical_data
    data['c'].data['a'].sub_(1)
    hd.record_hierarchical_data(data)
    tensors = hd.collect_all_tensors()
    for t in tensors:
        assert id(t) in all_tensors_id
    hd.quick()

    # Once nested data has been recorded, re-recording a bare tensor is
    # rejected.
    with pytest.raises(
            AssertionError,
            match='original hierarchical data is not torch.tensor'):
        hd.record_hierarchical_data(torch.rand(3, 4))

    class AuxClass:
        pass

    # Unsupported leaf types are rejected by record/update/collect/clean.
    # Only the call expected to raise sits inside each pytest.raises block.
    # NOTE(review): the mutations of hierarchical_data_sample presumably
    # reach `hd` because the manager keeps a reference to that same dict;
    # confirm against HierarchicalDataManager.
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.record_hierarchical_data(AuxClass())

    hierarchical_data_sample['a'] = AuxClass()
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.update_all_tensors(tensors)

    hierarchical_data_sample['a'] = AuxClass()
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.collect_all_tensors()

    hierarchical_data_sample['a'] = AuxClass()
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.clean_all_tensors()

    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(hierarchical_data_sample)
    hierarchical_data_sample['a'] = torch.rand(3, 4)
    # A strict update that replaces a non-tensor entry ('b') must fail.
    # The same update is accepted with strict=False.
    new_hierarchical_data_sample = {
        **hierarchical_data_sample, 'b': np.random.rand(3, 4)
    }
    with pytest.raises(ValueError, match='all data except torch.Tensor'):
        hd.update_hierarchical_data(new_hierarchical_data_sample)

    hd.update_hierarchical_data(new_hierarchical_data_sample, strict=False)
    hd.clean_all_tensors()

    # test single tensor
    single_tensor = torch.rand(3, 4)
    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(single_tensor)
    tensors = hd.collect_all_tensors()
    # Compare tensors by identity. `in` / `==` on tensors fall back to
    # element-wise comparison, whose truth value is ambiguous (it raises)
    # unless the objects happen to be identical.
    assert len(tensors) == 1 and tensors[0] is single_tensor
    single_tensor_to_update = [torch.rand(3, 4)]
    hd.update_all_tensors(single_tensor_to_update)
    new_tensors = hd.collect_all_tensors()
    assert len(new_tensors) == 1
    assert new_tensors[0] is single_tensor_to_update[0]