medical imaging
ultrasound
File size: 6,801 Bytes
6ce7d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Defines a Loader class to load data from a file or file wildcard
"""

import argparse

import h5py
import torch
import numpy as np
import glob
from typing import Tuple


class Loader:
    """
    Data loader class.

    Loads input/label arrays from HDF5 file(s) matched by a glob pattern,
    with optional cropping, nearest-neighbour resizing of the labels,
    global scaling, and depth-dependent gain compensation of the inputs.
    Configuration comes from the argparse arguments declared in
    ``add_argparse_args``; keyword arguments passed to ``__init__``
    override those defaults.
    """

    def __init__(self, **kwargs):
        """Initialize the loader; ``kwargs`` override the argparse defaults."""
        parser = Loader.add_argparse_args()
        # Patch matching defaults in place, then parse an empty argv so the
        # (possibly overridden) defaults land in this instance's __dict__.
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))

        # A single label variable may be given as a bare string; normalize it
        # to a one-element list so downstream code can always iterate it.
        if isinstance(self.label_vars, str):
            self.label_vars = [self.label_vars]

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Add argparse arguments for the data loader.

        Arguments:
            parent_parser -- optional parser whose arguments are inherited

        Returns:
            an ``argparse.ArgumentParser`` holding the loader's options
        """
        parser = argparse.ArgumentParser(
            prog='Loader',
            usage=Loader.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)

        # FIX: --input_var help previously said "label data" (copy-paste).
        parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the input data')
        parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')

        parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
                            help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
        parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*',
                            help='Crop label data on load [x_min x_max y_min y_max]')
        parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image')

        parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
        parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.')

        return parser

    def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Loads training/testing data from file(s)

        Arguments:
            test_file_pattern  {str} -- testing dataset(s) pattern
            train_file_pattern {str} -- training dataset(s) pattern

        Returns:
            (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded

        Raises:
            ValueError -- if a requested pattern could not be loaded
        """
        test_inputs, test_labels = self._load_data_files(test_file_pattern)
        train_inputs, train_labels = self._load_data_files(train_file_pattern)

        if train_file_pattern is not None and train_inputs is None:
            raise ValueError('Failed to load train set')

        # FIX: this branch previously raised 'Failed to load train set'.
        if test_file_pattern is not None and test_inputs is None:
            raise ValueError('Failed to load test set')

        return test_inputs, test_labels, train_inputs, train_labels

    def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """ Perform actual data loading

        Args:
            file_pattern: file name pattern (glob syntax); may be None

        Returns:
            inputs and labels tensors (both None when file_pattern is None;
            labels is None when no label variables are configured)

        Raises:
            ValueError: if the pattern matches no files or a configured
                variable is missing from the first file
        """

        inputs, labels = None, None

        if file_pattern is None:
            return inputs, labels

        files = glob.glob(file_pattern)

        if len(files) == 0:
            raise ValueError(f'{file_pattern=} comes up empty')

        # Load first file to get output dimensions; all files are assumed to
        # share the same per-file shape (TODO confirm with data producer).
        with h5py.File(files[0], 'r') as f:
            if self.input_var not in f:
                raise ValueError(f'input data key not in file: {self.input_var=}')

            shape = list(f[self.input_var].shape)
            if self.inputs_crop is not None:
                # Crop bounds are [.., min, max] pairs applied to the
                # trailing axes, innermost first.
                for i in range(len(self.inputs_crop) // 2):
                    shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]

            # First axis is the batch axis; allocate room for every file.
            shape[0] *= len(files)

            inputs = np.empty(shape, np.single)

            if self.label_vars:
                if not all(v in f for v in self.label_vars):
                    raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')

                shape = list(f[self.label_vars[0]].shape)
                # Label variables are concatenated along the channel axis.
                shape[1] *= len(self.label_vars)
                if self.labels_crop is not None:
                    for i in range(len(self.labels_crop) // 2):
                        shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]

                shape[-1] = int(shape[-1] * self.labels_resize)
                shape[-2] = int(shape[-2] * self.labels_resize)
                shape[0] *= len(files)

                labels = np.empty(shape, np.single)

        # Load data from files
        pos = 0
        for file in files:
            # FIX: previously opened files[0] on every iteration, so
            # multi-file loads silently duplicated the first file's data.
            with h5py.File(file, 'r') as f:
                tmp_inputs = np.array(f[self.input_var])

                if self.inputs_crop is not None:
                    # NOTE(review): slicing assumes 4-D data — confirm.
                    slc = [slice(None)] * 4
                    for i in range(len(self.inputs_crop) // 2):
                        slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
                    tmp_inputs = tmp_inputs[tuple(slc)]

                inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs

                if self.label_vars:
                    tmp_labels = np.concatenate([np.array(f[v]) for v in self.label_vars], axis=1)

                    if self.labels_crop:
                        slc = [slice(None)] * 4
                        for i in range(len(self.labels_crop) // 2):
                            slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
                        tmp_labels = tmp_labels[tuple(slc)]

                    if self.labels_resize != 1.0:
                        tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()

                    labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels

                pos += tmp_inputs.shape[0]

        # Trim in case files contributed fewer rows than pre-allocated.
        inputs = inputs[:pos, ...]
        if self.label_vars:
            labels = labels[:pos, ...]

        if self.data_scale != 1.0:
            inputs *= self.data_scale

        if self.data_gain != 0.0:
            # Depth gain: exponential ramp up to 10**data_gain at the last
            # sample along the final (depth) axis.
            # FIX: np.single was previously passed as linspace's positional
            # `endpoint` argument, not as dtype; it only "worked" because
            # np.single is truthy. Pass dtype explicitly.
            gain = 10.0 ** np.linspace(0.0, self.data_gain, inputs.shape[-1], dtype=np.single).reshape((1, 1, 1, -1))
            inputs *= gain

        # Required when inputs is non-continuous due to transpose
        # TODO: Could probably use a check on strides and do a conditional copy.
        inputs = torch.from_numpy(inputs.copy())
        if self.label_vars:
            labels = torch.from_numpy(labels)

        return inputs, labels