Spaces:
Sleeping
Sleeping
File size: 3,547 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import numpy as np
import pytest
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import IS_IPU_AVAILABLE
if IS_IPU_AVAILABLE:
from mmcv.device.ipu.hierarchical_data_manager import \
HierarchicalDataManager
# Skip marker for tests that need IPU support as reported by mmcv's
# IS_IPU_AVAILABLE flag; HierarchicalDataManager is only imported when that
# flag is set, so the decorated tests cannot run without it.
skip_no_ipu = pytest.mark.skipif(
    condition=not IS_IPU_AVAILABLE,
    reason='test case under ipu environment',
)
@skip_no_ipu
def test_HierarchicalData():
    """Exercise HierarchicalDataManager on nested and single-tensor data.

    Covers recording, collecting, updating and cleaning the tensors held in
    a nested structure (dicts, lists and ``DataContainer``), the error paths
    for unsupported types and changed non-tensor values, and the case where
    the recorded data is a single tensor.
    """
    # test hierarchical data
    hierarchical_data_sample = {
        'a': torch.rand(3, 4),
        'b': np.random.rand(3, 4),
        'c': DataContainer({
            'a': torch.rand(3, 4),
            'b': 4,
            'c': 'd'
        }),
        'd': 123,
        'e': [1, 3, torch.rand(3, 4),
              np.random.rand(3, 4)],
        'f': {
            'a': torch.rand(3, 4),
            'b': np.random.rand(3, 4),
            'c': [1, 'asd']
        }
    }
    # Every torch.Tensor in the sample. The numpy arrays are not torch
    # tensors, so exactly these four are expected to be collected.
    all_tensors = [
        hierarchical_data_sample['a'],
        hierarchical_data_sample['c'].data['a'],
        hierarchical_data_sample['e'][2],
        hierarchical_data_sample['f']['a'],
    ]
    all_tensors_id = {id(ele) for ele in all_tensors}

    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(hierarchical_data_sample)
    tensors = hd.collect_all_tensors()
    # Check the count as well: without it an empty result would pass the
    # identity loop vacuously.
    assert len(tensors) == len(all_tensors)
    for t in tensors:
        assert id(t) in all_tensors_id

    # In-place edits written back through the manager must keep the very
    # same tensor objects in the hierarchy.
    tensors[0].add_(1)
    hd.update_all_tensors(tensors)
    data = hd.hierarchical_data
    data['c'].data['a'].sub_(1)
    hd.record_hierarchical_data(data)
    tensors = hd.collect_all_tensors()
    assert len(tensors) == len(all_tensors)
    for t in tensors:
        assert id(t) in all_tensors_id

    # NOTE(review): presumably switches the manager into a faster mode with
    # fewer consistency checks -- confirm against HierarchicalDataManager.
    hd.quick()
    # A bare tensor may not replace a recorded non-tensor hierarchy.
    with pytest.raises(
            AssertionError,
            match='original hierarchical data is not torch.tensor'):
        hd.record_hierarchical_data(torch.rand(3, 4))

    class AuxClass:
        """A type the manager has no traversal rule for."""

    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.record_hierarchical_data(AuxClass())

    # The checks below rely on the manager holding a reference to
    # hierarchical_data_sample (not a copy): planting an unsupported value
    # in it must make every traversal fail. The setup is done outside
    # pytest.raises so that only the call under test is expected to raise.
    hierarchical_data_sample['a'] = AuxClass()
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.update_all_tensors(tensors)
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.collect_all_tensors()
    with pytest.raises(NotImplementedError, match='not supported datatype:'):
        hd.clean_all_tensors()

    # Changing a non-tensor leaf ('b') is an error in strict mode and is
    # tolerated (no exception) with strict=False.
    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(hierarchical_data_sample)
    hierarchical_data_sample['a'] = torch.rand(3, 4)
    new_hierarchical_data_sample = {
        **hierarchical_data_sample, 'b': np.random.rand(3, 4)
    }
    with pytest.raises(ValueError, match='all data except torch.Tensor'):
        hd.update_hierarchical_data(new_hierarchical_data_sample)
    hd.update_hierarchical_data(new_hierarchical_data_sample, strict=False)
    hd.clean_all_tensors()

    # test single tensor
    single_tensor = torch.rand(3, 4)
    hd = HierarchicalDataManager(logging.getLogger())
    hd.record_hierarchical_data(single_tensor)
    tensors = hd.collect_all_tensors()
    # Compare by identity: `in` / `==` on lists of multi-element tensors only
    # succeed via Python's identity shortcut, and on a mismatch they would
    # raise RuntimeError (ambiguous truth value) instead of failing cleanly.
    assert len(tensors) == 1 and tensors[0] is single_tensor
    single_tensor_to_update = [torch.rand(3, 4)]
    hd.update_all_tensors(single_tensor_to_update)
    new_tensors = hd.collect_all_tensors()
    assert len(new_tensors) == 1
    assert new_tensors[0] is single_tensor_to_update[0]
|