medical imaging
ultrasound
File size: 5,474 Bytes
6ce7d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Deep learning framework for sound speed inversion
"""

import json
import git
import argparse
import pathlib
import glob
import os
import h5py

import loader
import run_logger
import net

import torch
import torch.utils.data as td
import pytorch_lightning as pl


# ----------------------------
# Setup command line arguments
# ----------------------------

# __doc__ (the module docstring) becomes the --help description;
# ArgumentDefaultsHelpFormatter appends each option's default to its help text.
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Data selection: glob-style file patterns and where to write test output.
parser.add_argument('--test_files', nargs='?', help='Test data (file pattern) to process / data to evaluate ')
parser.add_argument('--train_files', nargs='?', help='Train data (file pattern) to process, only evaluate test if empty')
parser.add_argument('--test_fname', default='output.h5', help='Filename into which to write testing output -- will be overwritten')

parser.add_argument('--batch_size', type=int, default=32, help='Batch size')

# Experiment-tracking (MLflow) metadata.
parser.add_argument('--experiment', default='DeepLearning US', help='experiment name')
parser.add_argument('--tags', nargs='?', help='Optional run tags, should evaluate to dictionary via json.loads')

parser.add_argument('--load_ver', type=str, help='Network weights to load')

# Config files are merged in as parser *defaults*, so values given
# explicitly on the command line always win over file-provided ones.
parser.add_argument('--conf', type=str, action='append', help='Config file(s) to import (overridden by command line arguments)')
parser.add_argument('--conf_export', type=str, help='Filename where to store settings')


# Let each component (trainer, data loader, model, image-logging callback)
# register its own options on the shared parser.
parser = pl.Trainer.add_argparse_args(parser)
parser = loader.Loader.add_argparse_args(parser)
parser = net.Net.add_model_specific_args(parser)
parser = run_logger.ImgCB.add_argparse_args(parser)

args = parser.parse_args()

if args.conf is not None:
    # Fold each config file into the parser defaults, later files winning,
    # then re-parse so explicit command-line values still override them all.
    for cfg_fname in args.conf:
        overrides = json.loads(pathlib.Path(cfg_fname).read_text())
        parser.set_defaults(**overrides)

    args = parser.parse_args()

if args.conf_export is not None:
    # Snapshot the fully-resolved settings (defaults + config + CLI) to disk.
    settings = json.dumps(vars(args), indent=4, sort_keys=True)
    pathlib.Path(args.conf_export).write_text(settings)

if args.test_files is None and args.train_files is None:
    raise ValueError('At least one of train files or test files is required')

# ----------------------------
# Load data
# ----------------------------

ld = loader.Loader(**vars(args))
test_input, test_label, train_input, train_label = ld.load_data(test_file_pattern=args.test_files, train_file_pattern=args.train_files)

# Print a one-line summary per tensor; splits that were not loaded show None.
summary = {
    'test_input': test_input,
    'test_label': test_label,
    'train_input': train_input,
    'train_label': train_label,
}
for name, t in summary.items():
    if t is None:
        print(f'{name}: None -- None')
    else:
        print(f'{name}: {t.shape} -- {t.dtype}')

loaders = []

if args.train_files is not None:
    # Training needs labeled data; if test data is also supplied (used for
    # validation during training), it must be labeled as well.
    missing_train = train_input is None or train_label is None
    unlabeled_test = test_input is not None and test_label is None
    if missing_train or unlabeled_test:
        raise ValueError('Training requires labeled data')

    train_ds = td.TensorDataset(train_input, train_label)
    loaders.append(td.DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=True))

if args.test_files is not None:
    # Labels are optional at test time (pure prediction has none).
    tensors = [test_input] if test_label is None else [test_input, test_label]
    test_ds = td.TensorDataset(*tensors)
    loaders.append(td.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False, pin_memory=True))

# ----------------------------
# Run
# ----------------------------

if args.train_files is not None:
    # Normalize tags into a dict (they may arrive as a JSON string from the CLI).
    if args.tags is None:
        args.tags = {}
    elif isinstance(args.tags, str):
        args.tags = json.loads(args.tags)

    # Tag the run with the current commit for reproducibility. Best-effort:
    # skip quietly when not running from a git checkout. Catch Exception
    # (not a bare except) so Ctrl-C / SystemExit still propagate.
    try:
        repo = git.Repo(search_parent_directories=True)
        args.tags.update({'commit': repo.head.object.hexsha})
    except Exception:
        print('Not a git repo, not logging commit ID')

    mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
    mfl.log_hyperparams(args)

    # Archive every .py file next to this script as a run artifact so the
    # exact training code is recoverable from the tracking server.
    path = pathlib.Path(__file__).parent.absolute()
    for src in path.glob('*.py'):
        mfl.experiment.log_artifact(mfl.run_id, str(src), 'source')

    # Keep only the best checkpoint by validation metric (lower is better).
    # NOTE(review): every_n_train_steps=1 with a validation-metric monitor
    # looks unusual -- the metric is typically only updated per validation
    # pass, not per train step. Confirm intended checkpoint cadence.
    chkpnt_cb = pl.callbacks.ModelCheckpoint(
        monitor='validate_mean',
        verbose=True,
        save_top_k=1,
        save_weights_only=True,
        mode='min',
        every_n_train_steps=1,
        filename='{epoch}-{validate_mean}-{train_mean}',
    )

    img_cb = run_logger.ImgCB(**vars(args))
    lr_logger = pl.callbacks.LearningRateMonitor()

    # Inject logger/callbacks into args so Trainer.from_argparse_args picks them up.
    args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
else:
    # Evaluation-only run: start from a clean output file and log test results.
    if os.path.exists(args.test_fname):
        os.remove(args.test_fname)

    args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})


# Infer network input/output widths from dim 1 of whichever split is
# available, preferring the test tensors when both exist.
label_source = test_label if test_label is not None else train_label
if label_source is not None:
    args.n_outputs = label_source.shape[1]

input_source = test_input if test_input is not None else train_input
if input_source is not None:
    args.n_inputs = input_source.shape[1]


n = net.Net(**vars(args))
if args.load_ver is not None:
    # Restore previously trained weights from a Lightning checkpoint,
    # mapping to CPU so loading works without a GPU.
    state = torch.load(args.load_ver, map_location='cpu')['state_dict']
    n.load_state_dict(state)

trainer = pl.Trainer.from_argparse_args(args)

if args.train_files is not None:
    # Train (with optional validation loader) and report the best checkpoint.
    trainer.fit(n, *loaders)
    print(chkpnt_cb.best_model_path)
elif args.label_vars:
    # Labeled test data: run the evaluation loop.
    trainer.test(n, *loaders)
else:
    # Unlabeled data: predict and dump results to HDF5.
    predictions = trainer.predict(n, *loaders)
    with h5py.File(args.test_fname, "w") as F:
        F["predictions"] = torch.cat(predictions).numpy()