File size: 6,801 Bytes
6ce7d82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
"""
Defines a Loader class to load data from a file or file wildcard
"""
import argparse
import h5py
import torch
import numpy as np
import glob
from typing import Tuple
class Loader:
    """
    Data loader class
    """
    def __init__(self, **kwargs):
        """Create a Loader; **kwargs override the argparse defaults by dest name."""
        parser = Loader.add_argparse_args()
        # Override declared defaults with any matching keyword argument, then
        # parse an empty argv so only the (possibly overridden) defaults apply.
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))
        # Normalize a single label-variable name to a one-element list.
        if isinstance(self.label_vars, str):
            self.label_vars = [self.label_vars]
    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Add argparse arguments for the data loader.

        Args:
            parent_parser: optional parser whose arguments are inherited.

        Returns:
            An argparse.ArgumentParser with all Loader options declared.
        """
        parser = argparse.ArgumentParser(
            prog='Loader',
            usage=Loader.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)
        # FIX: help text previously said 'label data' for the input variable.
        parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the input data')
        parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')
        parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
                            help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
        parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]')
        parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image')
        parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
        parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.')
        return parser
    def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Loads training/testing data from file(s)

        Arguments:
            test_file_pattern {str} -- testing dataset(s) pattern
            train_file_pattern {str} -- training dataset(s) pattern

        Raises:
            ValueError: when a non-None pattern fails to load.

        Returns:
            (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded
        """
        test_inputs, test_labels = self._load_data_files(test_file_pattern)
        train_inputs, train_labels = self._load_data_files(train_file_pattern)
        if train_file_pattern is not None and train_inputs is None:
            raise ValueError('Failed to load train set')
        if test_file_pattern is not None and test_inputs is None:
            # FIX: previously reported 'train set' on a test-set failure.
            raise ValueError('Failed to load test set')
        return test_inputs, test_labels, train_inputs, train_labels
    def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """ Perform actual data loading

        Args:
            file_pattern: file name pattern (glob syntax); None loads nothing.

        Raises:
            ValueError: when the pattern matches no files or a configured
                input/label variable is missing from the first file.

        Returns:
            inputs and labels tensors (labels is None when no label_vars set)
        """
        inputs, labels = None, None
        if file_pattern is None:
            return inputs, labels
        files = glob.glob(file_pattern)
        if len(files) == 0:
            raise ValueError(f'{file_pattern=} comes up empty')
        # Load first file to get output dimensions; all files are assumed to
        # share the first file's per-file shape (TODO confirm with data source).
        with h5py.File(files[0], 'r') as f:
            if self.input_var not in f:
                raise ValueError(f'input data key not in file: {self.input_var=}')
            shape = list(f[self.input_var].shape)
            if self.inputs_crop is not None:
                # Crop bounds are (min, max) pairs applied to trailing axes.
                for i in range(len(self.inputs_crop) // 2):
                    shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]
            shape[0] *= len(files)
            inputs = np.empty(shape, np.single)
            if len(self.label_vars):
                if not all([v in f for v in self.label_vars]):
                    raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')
                shape = list(f[self.label_vars[0]].shape)
                # Label variables are concatenated along the channel axis.
                shape[1] *= len(self.label_vars)
                if self.labels_crop is not None:
                    for i in range(len(self.labels_crop) // 2):
                        shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]
                shape[-1] = int(shape[-1] * self.labels_resize)
                shape[-2] = int(shape[-2] * self.labels_resize)
                shape[0] *= len(files)
                labels = np.empty(shape, np.single)
        # Load data from files
        pos = 0
        for file in files:
            # BUG FIX: previously opened files[0] on every iteration, so every
            # file's slot was filled with the FIRST file's data.
            with h5py.File(file, 'r') as f:
                tmp_inputs = np.array(f[self.input_var])
                if self.inputs_crop is not None:
                    slc = [slice(None)] * 4
                    for i in range(len(self.inputs_crop) // 2):
                        slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
                    tmp_inputs = tmp_inputs[tuple(slc)]
                inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs
                if len(self.label_vars):
                    tmp_labels = []
                    for v in self.label_vars:
                        tmp_labels.append(np.array(f[v]))
                    tmp_labels = np.concatenate(tmp_labels, axis=1)
                    if self.labels_crop:  # None and [] both mean "no crop"
                        slc = [slice(None)] * 4
                        for i in range(len(self.labels_crop) // 2):
                            slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
                        tmp_labels = tmp_labels[tuple(slc)]
                    if self.labels_resize != 1.0:
                        tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()
                    labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels
                pos += tmp_inputs.shape[0]
        # Trim pre-allocated buffers in case files held fewer rows than assumed.
        inputs = inputs[:pos, ...]
        if len(self.label_vars):
            labels = labels[:pos, ...]
        if self.data_scale != 1.0:
            inputs *= self.data_scale
        if self.data_gain != 0.0:
            # BUG FIX: np.single was being passed positionally as `endpoint`,
            # so the requested dtype was silently ignored; pass dtype= instead.
            # NOTE(review): the (1, 1, 1, -1) reshape assumes 4-D inputs — confirm.
            gain = 10.0 ** np.linspace(0, self.data_gain, inputs.shape[-1], dtype=np.single).reshape((1, 1, 1, -1))
            inputs *= gain
        # Required when inputs is non-continuous due to transpose
        # TODO: Could probably use a check on strides and do a conditional copy.
        inputs = torch.from_numpy(inputs.copy())
        if len(self.label_vars):
            labels = torch.from_numpy(labels)
        return inputs, labels
|