File size: 5,474 Bytes
6ce7d82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
"""
Deep learning framework for sound speed inversion
"""
import json
import git
import argparse
import pathlib
import glob
import os
import h5py
import loader
import run_logger
import net
import torch
import torch.utils.data as td
import pytorch_lightning as pl
# ----------------------------
# Setup command line arguments
# ----------------------------
# Script-level CLI: this script's own options plus arguments contributed by
# the Lightning Trainer, the data Loader, the network, and the image callback.
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--test_files', nargs='?', help='Test data (file pattern) to process / data to evaluate ')
parser.add_argument('--train_files', nargs='?', help='Train data (file pattern) to process, only evaluate test if empty')
parser.add_argument('--test_fname', default='output.h5', help='Filename into which to write testing output -- will be overwritten')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--experiment', default='DeepLearning US', help='experiment name')
parser.add_argument('--tags', nargs='?', help='Optional run tags, should evaluate to dictionary via json.loads')
parser.add_argument('--load_ver', type=str, help='Network weights to load')
parser.add_argument('--conf', type=str, action='append', help='Config file(s) to import (overridden by command line arguments)')
parser.add_argument('--conf_export', type=str, help='Filename where to store settings')
# Let each component register its own arguments on the shared parser.
parser = pl.Trainer.add_argparse_args(parser)
parser = loader.Loader.add_argparse_args(parser)
parser = net.Net.add_model_specific_args(parser)
parser = run_logger.ImgCB.add_argparse_args(parser)
args = parser.parse_args()
if args.conf is not None:
    # Each config file is a JSON dict applied as new parser defaults;
    # later files override earlier ones.
    for conf_fname in args.conf:
        with open(conf_fname, 'r') as f:
            parser.set_defaults(**json.load(f))
    # Reload arguments to override config file values with command line values
    args = parser.parse_args()
if args.conf_export is not None:
    # Persist the fully-resolved settings so the run can be reproduced.
    with open(args.conf_export, 'w') as f:
        json.dump(vars(args), f, indent=4, sort_keys=True)
if args.test_files is None and args.train_files is None:
    raise ValueError('At least one of train files or test files is required')
# ----------------------------
# Load data
# ----------------------------
data_source = loader.Loader(**vars(args))
test_input, test_label, train_input, train_label = data_source.load_data(test_file_pattern=args.test_files, train_file_pattern=args.train_files)
# Report shape/dtype of every loaded tensor (None when not loaded).
for name, tensor in {
        'test_input': test_input,
        'test_label': test_label,
        'train_input': train_input,
        'train_label': train_label}.items():
    shape = None if tensor is None else tensor.shape
    dtype = None if tensor is None else tensor.dtype
    print(f'{name}: {shape} -- {dtype}')
loaders = []
if args.train_files is not None:
    # Training needs labeled train data; if test data is also present it is
    # used for validation, so it must be labeled too.
    train_unlabeled = train_input is None or train_label is None
    val_unlabeled = test_input is not None and test_label is None
    if train_unlabeled or val_unlabeled:
        raise ValueError('Training requires labeled data')
    loaders.append(td.DataLoader(
        td.TensorDataset(train_input, train_label),
        batch_size=args.batch_size, shuffle=True, pin_memory=True))
if args.test_files is not None:
    # Test labels are optional (pure inference when absent).
    tensors = (test_input,) if test_label is None else (test_input, test_label)
    loaders.append(td.DataLoader(
        td.TensorDataset(*tensors),
        args.batch_size, shuffle=False, pin_memory=True))
# ----------------------------
# Run
# ----------------------------
if args.train_files is not None:
    # Training run: log hyperparameters, source and checkpoints to MLflow.
    if args.tags is None:
        args.tags = {}
    elif isinstance(args.tags, str):
        # Tags arrive from the command line as a JSON-encoded dict.
        args.tags = json.loads(args.tags)
    try:
        # Record the current commit so the run can be traced to its code.
        repo = git.Repo(search_parent_directories=True)
        args.tags.update({'commit': repo.head.object.hexsha})
    except Exception:
        # Best effort only: a missing repo or unreadable HEAD must not abort
        # training.  (Was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        print('Not a git repo, not logging commit ID')
    mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
    mfl.log_hyperparams(args)
    # Snapshot every Python file next to this script as a run artifact.
    src_dir = pathlib.Path(__file__).parent.absolute()
    for src_file in glob.glob(str(src_dir) + os.sep + '*.py'):
        mfl.experiment.log_artifact(mfl.run_id, src_file, 'source')
    # Keep only the best checkpoint (weights only) by minimum validation loss.
    # NOTE(review): every_n_train_steps=1 checkpoints every training step
    # while monitoring a validation metric -- confirm this cadence is intended.
    chkpnt_cb = pl.callbacks.ModelCheckpoint(
        monitor='validate_mean',
        verbose=True,
        save_top_k=1,
        save_weights_only=True,
        mode='min',
        every_n_train_steps=1,
        filename='{epoch}-{validate_mean}-{train_mean}',
    )
    img_cb = run_logger.ImgCB(**vars(args))
    lr_logger = pl.callbacks.LearningRateMonitor()
    args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
else:
    # Test-only run: write predictions/metrics into a fresh HDF5 file.
    if os.path.exists(args.test_fname):
        os.remove(args.test_fname)
    args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})
# Derive the network's output/input widths from dimension 1 of whichever
# tensors were loaded, preferring the test set over the train set.
label_src = test_label if test_label is not None else train_label
if label_src is not None:
    args.n_outputs = label_src.shape[1]
input_src = test_input if test_input is not None else train_input
if input_src is not None:
    args.n_inputs = input_src.shape[1]
# Build the model, optionally restore pretrained weights, then dispatch to
# training, evaluation, or pure inference.
model = net.Net(**vars(args))
if args.load_ver is not None:
    # Load the checkpoint onto the CPU; the Trainer moves it to the device.
    state = torch.load(args.load_ver, map_location='cpu')['state_dict']
    model.load_state_dict(state)
trainer = pl.Trainer.from_argparse_args(args)
if args.train_files is not None:
    # Train; a second loader, if present, presumably acts as validation data
    # (it was required to be labeled above).
    trainer.fit(model, *loaders)
    print(chkpnt_cb.best_model_path)
elif args.label_vars:
    # Labeled test data: run the evaluation loop.
    trainer.test(model, *loaders)
else:
    # Unlabeled data: run inference and dump raw predictions to HDF5.
    preds = trainer.predict(model, *loaders)
    with h5py.File(args.test_fname, "w") as out_file:
        out_file["predictions"] = torch.cat(preds).numpy()
|