laughingrice
commited on
Commit
·
6ce7d82
1
Parent(s):
d10cbbd
Upload 11 files
Browse files- README.md +45 -3
- __main__.py +170 -0
- environment.yml +286 -0
- loader.py +176 -0
- models/embc_sos.pt +3 -0
- models/tbme2_attn.pt +3 -0
- models/tbme2_phase_sos.pt +3 -0
- models/tbme2_sos.pt +3 -0
- models/tbme_sos.pt +3 -0
- net.py +494 -0
- run_logger.py +120 -0
README.md
CHANGED
@@ -1,3 +1,45 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deep learning for speed of sound inversion in ultrasound imaging
|
2 |
+
|
3 |
+
This repository contains the code and models for the following papers:
|
4 |
+
|
5 |
+
|
6 |
+
1. Feigin M, Freedman D, Anthony B. W. A Deep Learning Framework for Single-Sided Sound Speed Inversion in Medical Ultrasound. IEEE Trans Biomed Eng. 2020;67(4):1142-1151. doi:10.1109/TBME.2019.2931195
|
7 |
+
2. Feigin M, Zwecker M, Freedman D, Anthony BW. Detecting muscle activation using ultrasound speed of sound inversion with deep learning. In: 2020 42nd Annual International Conference of the IEEE Engineering in Medicine & Biology Society (EMBC). IEEE; 2020:2092-2095. doi:10.1109/EMBC44109.2020.9175237
|
8 |
+
3. Feigin M, Freedman D, Anthony BW. Computing Speed-of-Sound from ultrasound: user-agnostic recovery and a new benchmark. IEEE Trans Biomed Eng. 2023; doi:TBF
|
9 |
+
|
10 |
+
This repository contain the network code and models for the algorithms and results contained in the paper.
|
11 |
+
|
12 |
+
The code was tested under python 3.9. The anaconda environment is defined in environment.yml (setup environment with the `command conda env create -f environment.yml`)
|
13 |
+
|
14 |
+
## Data
|
15 |
+
|
16 |
+
The dataset used is available on huggingface at https://huggingface.co/datasets/laughingrice/Ultrasound_planewave_sos_inversion
|
17 |
+
|
18 |
+
Variables in the files are `[sample, layer, x/channel, y/sample]` order
|
19 |
+
|
20 |
+
* `alpha_coeff` -- Alpha coefficient used for simulations, full resolution
|
21 |
+
* `c0` -- Speed-of-sound used for simulations, full resolution
|
22 |
+
* `data` -- Channel data (first 2048 samples, 64 active channels, first layer with flat plane wave, to
|
23 |
+
match existing physical hardware were used for the results in the paper)
|
24 |
+
* `dx` -- spatial dx value of `c0` and `alpha_coef`
|
25 |
+
* `f` -- temporal sampling frequency of channel data (40MHz)
|
26 |
+
|
27 |
+
## Models
|
28 |
+
|
29 |
+
Model files appearing under the `models` directory for results presented in the paper with teh matching
|
30 |
+
execution parameters are as follows:
|
31 |
+
|
32 |
+
* `tbme_sos.pt` -- network weights for the network presented in [1]
|
33 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme_sos.h5 --load_ver models/tbme_sos.pt --net_type tbme`
|
34 |
+
* `embc_sos.pt` -- network weights for the network presented in [2]
|
35 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname embc_sos.h5 --load_ver models/embc_sos.pt --net_type embc`
|
36 |
+
* `tbme2_sos.pt` -- network weights for the network presented in [3]
|
37 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos.h5 --load_ver models/tbme2_sos.pt`
|
38 |
+
* `tbme2_sos_rand_gain.pt` -- [3] trained to recover the speed-of-sound map with random gain profile and scaling
|
39 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos_gain.h5 --load_ver models/tbme2_sos_rand_gain.pt`
|
40 |
+
* `tbme2_attn.pt` -- [3] trained to recover the attenuation coefficient
|
41 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_attn.h5 --load_ver models/tbme2_attn.pt --label_vars alpha_coeff`
|
42 |
+
* `tbme2_sos_attn.pt` -- [3] trained to recover both the speed-of-sound map and attenuation coefficient
|
43 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos_attn.h5 --load_ver models/tbme2_sos_attn.pt --label_vars c0 alpha_coeff`
|
44 |
+
* `tbme2_phase_sos.pt` -- [3] trained to recover the speed-of-sound map using the IQ phase component
|
45 |
+
* `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_phase_sos.h5 --load_ver models/tbme2_phase_sos.pt --phase_inv 1`
|
__main__.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Deep learning framework for sound speed inversion
|
3 |
+
"""
|
4 |
+
|
5 |
+
import json
|
6 |
+
import git
|
7 |
+
import argparse
|
8 |
+
import pathlib
|
9 |
+
import glob
|
10 |
+
import os
|
11 |
+
import h5py
|
12 |
+
|
13 |
+
import loader
|
14 |
+
import run_logger
|
15 |
+
import net
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.utils.data as td
|
19 |
+
import pytorch_lightning as pl
|
20 |
+
|
21 |
+
|
22 |
+
# ----------------------------
|
23 |
+
# Setup command line arguments
|
24 |
+
# ----------------------------
|
25 |
+
|
26 |
+
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
27 |
+
|
28 |
+
parser.add_argument('--test_files', nargs='?', help='Test data (file pattern) to process / data to evaluate ')
|
29 |
+
parser.add_argument('--train_files', nargs='?', help='Train data (file pattern) to process, only evaluate test if empty')
|
30 |
+
parser.add_argument('--test_fname', default='output.h5', help='Filename into which to write testing output -- will be overwritten')
|
31 |
+
|
32 |
+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
|
33 |
+
|
34 |
+
parser.add_argument('--experiment', default='DeepLearning US', help='experiment name')
|
35 |
+
parser.add_argument('--tags', nargs='?', help='Optional run tags, should evaluate to dictionary via json.loads')
|
36 |
+
|
37 |
+
parser.add_argument('--load_ver', type=str, help='Network weights to load')
|
38 |
+
|
39 |
+
parser.add_argument('--conf', type=str, action='append', help='Config file(s) to import (overridden by command line arguments)')
|
40 |
+
parser.add_argument('--conf_export', type=str, help='Filename where to store settings')
|
41 |
+
|
42 |
+
|
43 |
+
parser = pl.Trainer.add_argparse_args(parser)
|
44 |
+
parser = loader.Loader.add_argparse_args(parser)
|
45 |
+
parser = net.Net.add_model_specific_args(parser)
|
46 |
+
parser = run_logger.ImgCB.add_argparse_args(parser)
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
|
50 |
+
if args.conf is not None:
|
51 |
+
for conf_fname in args.conf:
|
52 |
+
with open(conf_fname, 'r') as f:
|
53 |
+
parser.set_defaults(**json.load(f))
|
54 |
+
|
55 |
+
# Reload arguments to override config file values with command line values
|
56 |
+
args = parser.parse_args()
|
57 |
+
|
58 |
+
if args.conf_export is not None:
|
59 |
+
with open(args.conf_export, 'w') as f:
|
60 |
+
json.dump(vars(args), f, indent=4, sort_keys=True)
|
61 |
+
|
62 |
+
if args.test_files is None and args.train_files is None:
|
63 |
+
raise ValueError('At least one of train files or test files is required')
|
64 |
+
|
65 |
+
# ----------------------------
|
66 |
+
# Load data
|
67 |
+
# ----------------------------
|
68 |
+
|
69 |
+
ld = loader.Loader(**vars(args))
|
70 |
+
test_input, test_label, train_input, train_label = ld.load_data(test_file_pattern=args.test_files, train_file_pattern=args.train_files)
|
71 |
+
|
72 |
+
for name, tensor in (
|
73 |
+
('test_input', test_input),
|
74 |
+
('test_label', test_label),
|
75 |
+
('train_input', train_input),
|
76 |
+
('train_label', train_label)):
|
77 |
+
print(f'{name}: {tensor.shape if tensor is not None else None} -- {tensor.dtype if tensor is not None else None}')
|
78 |
+
|
79 |
+
loaders = []
|
80 |
+
|
81 |
+
if args.train_files is not None:
|
82 |
+
if train_input is None or train_label is None or (test_input is not None and test_label is None):
|
83 |
+
raise ValueError('Training requires labeled data')
|
84 |
+
|
85 |
+
train_ds = td.TensorDataset(train_input, train_label)
|
86 |
+
loaders.append(td.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=True))
|
87 |
+
|
88 |
+
if args.test_files is not None:
|
89 |
+
ds = [test_input]
|
90 |
+
if test_label is not None:
|
91 |
+
ds.append(test_label)
|
92 |
+
|
93 |
+
test_ds = td.TensorDataset(*ds)
|
94 |
+
loaders.append(td.DataLoader(test_ds, args.batch_size, shuffle=False, pin_memory=True))
|
95 |
+
|
96 |
+
# ----------------------------
|
97 |
+
# Run
|
98 |
+
# ----------------------------
|
99 |
+
|
100 |
+
if args.train_files is not None:
|
101 |
+
if args.tags is None:
|
102 |
+
args.tags = {}
|
103 |
+
elif type(args.tags) == str:
|
104 |
+
args.tags = json.loads(args.tags)
|
105 |
+
|
106 |
+
try:
|
107 |
+
repo = git.Repo(search_parent_directories=True)
|
108 |
+
sha = repo.head.object.hexsha
|
109 |
+
args.tags.update({'commit': sha})
|
110 |
+
except:
|
111 |
+
print('Not a git repo, not logging commit ID')
|
112 |
+
|
113 |
+
mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
|
114 |
+
mfl.log_hyperparams(args)
|
115 |
+
|
116 |
+
path = pathlib.Path(__file__).parent.absolute()
|
117 |
+
files = glob.glob(str(path) + os.sep + '*.py')
|
118 |
+
for f in files:
|
119 |
+
mfl.experiment.log_artifact(mfl.run_id, f, 'source')
|
120 |
+
|
121 |
+
chkpnt_cb = pl.callbacks.ModelCheckpoint(
|
122 |
+
monitor='validate_mean',
|
123 |
+
verbose=True,
|
124 |
+
save_top_k=1,
|
125 |
+
save_weights_only=True,
|
126 |
+
mode='min',
|
127 |
+
every_n_train_steps=1,
|
128 |
+
filename='{epoch}-{validate_mean}-{train_mean}',
|
129 |
+
)
|
130 |
+
|
131 |
+
img_cb = run_logger.ImgCB(**vars(args))
|
132 |
+
lr_logger = pl.callbacks.LearningRateMonitor()
|
133 |
+
|
134 |
+
args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
|
135 |
+
else:
|
136 |
+
if os.path.exists(args.test_fname):
|
137 |
+
os.remove(args.test_fname)
|
138 |
+
|
139 |
+
args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})
|
140 |
+
|
141 |
+
|
142 |
+
if test_label is not None:
|
143 |
+
args.n_outputs = test_label.shape[1]
|
144 |
+
elif train_label is not None:
|
145 |
+
args.n_outputs = train_label.shape[1]
|
146 |
+
|
147 |
+
if test_input is not None:
|
148 |
+
args.n_inputs = test_input.shape[1]
|
149 |
+
elif train_input is not None:
|
150 |
+
args.n_inputs = train_input.shape[1]
|
151 |
+
|
152 |
+
|
153 |
+
n = net.Net(**vars(args))
|
154 |
+
if args.load_ver is not None:
|
155 |
+
t = torch.load(args.load_ver, map_location='cpu')['state_dict']
|
156 |
+
n.load_state_dict(t)
|
157 |
+
|
158 |
+
trainer = pl.Trainer.from_argparse_args(args)
|
159 |
+
|
160 |
+
if args.train_files is not None:
|
161 |
+
trainer.fit(n, *loaders)
|
162 |
+
|
163 |
+
print(chkpnt_cb.best_model_path)
|
164 |
+
elif args.label_vars:
|
165 |
+
trainer.test(n, *loaders)
|
166 |
+
else:
|
167 |
+
predictions = trainer.predict(n, *loaders)
|
168 |
+
with h5py.File(args.test_fname, "w") as F:
|
169 |
+
F["predictions"] = torch.cat(predictions).numpy()
|
170 |
+
|
environment.yml
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: DL-US-inversion
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
- pytorch
|
6 |
+
- gimli
|
7 |
+
dependencies:
|
8 |
+
- _libgcc_mutex=0.1=conda_forge
|
9 |
+
- _openmp_mutex=4.5=2_kmp_llvm
|
10 |
+
- abseil-cpp=20211102.0=h27087fc_1
|
11 |
+
- absl-py=1.1.0=pyhd8ed1ab_0
|
12 |
+
- aiohttp=3.8.1=py310h5764c6d_1
|
13 |
+
- aiosignal=1.2.0=pyhd8ed1ab_0
|
14 |
+
- alembic=1.8.1=pyhd8ed1ab_0
|
15 |
+
- aom=3.4.0=h27087fc_1
|
16 |
+
- appdirs=1.4.4=pyh9f0ad1d_0
|
17 |
+
- arrow-cpp=8.0.0=py310h893e394_4_cpu
|
18 |
+
- asn1crypto=1.5.1=pyhd8ed1ab_0
|
19 |
+
- asttokens=2.0.5=pyhd8ed1ab_0
|
20 |
+
- async-timeout=4.0.2=pyhd8ed1ab_0
|
21 |
+
- attrs=21.4.0=pyhd8ed1ab_0
|
22 |
+
- aws-c-cal=0.5.11=h95a6274_0
|
23 |
+
- aws-c-common=0.6.2=h7f98852_0
|
24 |
+
- aws-c-event-stream=0.2.7=h3541f99_13
|
25 |
+
- aws-c-io=0.10.5=hfb6a706_0
|
26 |
+
- aws-checksums=0.1.11=ha31a3da_7
|
27 |
+
- aws-sdk-cpp=1.8.186=hb4091e7_3
|
28 |
+
- backcall=0.2.0=pyh9f0ad1d_0
|
29 |
+
- backports=1.1=pyhd3eb1b0_0
|
30 |
+
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
|
31 |
+
- blas=2.115=mkl
|
32 |
+
- blas-devel=3.9.0=15_linux64_mkl
|
33 |
+
- blinker=1.4=py_1
|
34 |
+
- boost-cpp=1.79.0=h75c5d50_0
|
35 |
+
- brotli=1.0.9=h166bdaf_7
|
36 |
+
- brotli-bin=1.0.9=h166bdaf_7
|
37 |
+
- brotlipy=0.7.0=py310h5764c6d_1004
|
38 |
+
- bzip2=1.0.8=h7f98852_4
|
39 |
+
- c-ares=1.18.1=h7f98852_0
|
40 |
+
- ca-certificates=2022.6.15=ha878542_0
|
41 |
+
- cached-property=1.5.2=hd8ed1ab_1
|
42 |
+
- cached_property=1.5.2=pyha770c72_1
|
43 |
+
- cachetools=5.0.0=pyhd8ed1ab_0
|
44 |
+
- certifi=2022.6.15=py310hff52083_0
|
45 |
+
- cffi=1.15.1=py310h255011f_0
|
46 |
+
- charset-normalizer=2.1.0=pyhd8ed1ab_0
|
47 |
+
- click=8.1.3=py310hff52083_0
|
48 |
+
- cloudpickle=2.1.0=pyhd8ed1ab_0
|
49 |
+
- colorama=0.4.5=pyhd8ed1ab_0
|
50 |
+
- configparser=5.2.0=pyhd8ed1ab_0
|
51 |
+
- cryptography=37.0.1=py310h9ce1e76_0
|
52 |
+
- cudatoolkit=11.6.0=hecad31d_10
|
53 |
+
- cudnn=8.4.1.50=hed8a83a_0
|
54 |
+
- cycler=0.11.0=pyhd8ed1ab_0
|
55 |
+
- databricks-cli=0.17.0=pyhd8ed1ab_0
|
56 |
+
- decorator=5.1.1=pyhd8ed1ab_0
|
57 |
+
- docker-py=5.0.3=py310hff52083_2
|
58 |
+
- docker-pycreds=0.4.0=py_0
|
59 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
60 |
+
- executing=0.8.3=pyhd8ed1ab_0
|
61 |
+
- expat=2.4.8=h27087fc_0
|
62 |
+
- ffmpeg=5.0.1=gpl_h512afef_107
|
63 |
+
- flask=2.1.3=pyhd8ed1ab_0
|
64 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
65 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
66 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
67 |
+
- font-ttf-ubuntu=0.83=hab24e00_0
|
68 |
+
- fontconfig=2.14.0=h8e229c2_0
|
69 |
+
- fonts-conda-ecosystem=1=0
|
70 |
+
- fonts-conda-forge=1=0
|
71 |
+
- fonttools=4.34.4=py310h5764c6d_0
|
72 |
+
- freetype=2.11.0=h70c0345_0
|
73 |
+
- frozenlist=1.3.0=py310h5764c6d_1
|
74 |
+
- fsspec=2022.5.0=pyhd8ed1ab_0
|
75 |
+
- future=0.18.2=py310hff52083_5
|
76 |
+
- gettext=0.21.0=hf68c758_0
|
77 |
+
- gflags=2.2.2=he1b5a44_1004
|
78 |
+
- giflib=5.2.1=h36c2ea0_2
|
79 |
+
- gitdb=4.0.9=pyhd8ed1ab_0
|
80 |
+
- gitpython=3.1.27=pyhd8ed1ab_0
|
81 |
+
- glog=0.6.0=h6f12383_0
|
82 |
+
- gmp=6.2.1=h58526e2_0
|
83 |
+
- gnutls=3.7.6=hf3e180e_5
|
84 |
+
- google-auth=2.9.1=pyh6c4a22f_0
|
85 |
+
- google-auth-oauthlib=0.4.1=py_2
|
86 |
+
- greenlet=1.1.2=py310hd8f1fbe_2
|
87 |
+
- grpc-cpp=1.46.3=hbd84cd8_2
|
88 |
+
- grpcio=1.46.3=py310ha0b7d45_2
|
89 |
+
- gunicorn=20.1.0=py310hff52083_2
|
90 |
+
- h5py=3.7.0=nompi_py310h06dffec_100
|
91 |
+
- hdf5=1.12.1=nompi_h2386368_104
|
92 |
+
- htmlmin=0.1.12=py_1
|
93 |
+
- icu=70.1=h27087fc_0
|
94 |
+
- idna=3.3=pyhd8ed1ab_0
|
95 |
+
- imagehash=4.2.1=pyhd8ed1ab_0
|
96 |
+
- importlib-metadata=4.11.4=py310hff52083_0
|
97 |
+
- importlib_resources=5.8.0=pyhd8ed1ab_0
|
98 |
+
- ipython=8.4.0=py310hff52083_0
|
99 |
+
- itsdangerous=2.1.2=pyhd8ed1ab_0
|
100 |
+
- jedi=0.18.1=py310hff52083_1
|
101 |
+
- jinja2=3.1.2=pyhd8ed1ab_1
|
102 |
+
- joblib=1.1.0=pyhd8ed1ab_0
|
103 |
+
- jpeg=9e=h166bdaf_2
|
104 |
+
- keyutils=1.6.1=h166bdaf_0
|
105 |
+
- kiwisolver=1.4.4=py310hbf28c38_0
|
106 |
+
- krb5=1.19.3=h3790be6_0
|
107 |
+
- lame=3.100=h7f98852_1001
|
108 |
+
- lcms2=2.12=hddcbb42_0
|
109 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
110 |
+
- lerc=3.0=h9c3ff4c_0
|
111 |
+
- libblas=3.9.0=15_linux64_mkl
|
112 |
+
- libbrotlicommon=1.0.9=h166bdaf_7
|
113 |
+
- libbrotlidec=1.0.9=h166bdaf_7
|
114 |
+
- libbrotlienc=1.0.9=h166bdaf_7
|
115 |
+
- libcblas=3.9.0=15_linux64_mkl
|
116 |
+
- libcrc32c=1.1.2=h9c3ff4c_0
|
117 |
+
- libcurl=7.83.1=h7bff187_0
|
118 |
+
- libdeflate=1.12=h166bdaf_0
|
119 |
+
- libdrm=2.4.112=h166bdaf_0
|
120 |
+
- libedit=3.1.20210910=h7f8727e_0
|
121 |
+
- libev=4.33=h516909a_1
|
122 |
+
- libevent=2.1.10=h9b69904_4
|
123 |
+
- libffi=3.4.2=h7f98852_5
|
124 |
+
- libgcc-ng=12.1.0=h8d9b700_16
|
125 |
+
- libgfortran-ng=12.1.0=h69a702a_16
|
126 |
+
- libgfortran5=12.1.0=hdcd56e2_16
|
127 |
+
- libgomp=12.1.0=h8d9b700_16
|
128 |
+
- libgoogle-cloud=1.40.2=hefc27d0_0
|
129 |
+
- libiconv=1.16=h516909a_0
|
130 |
+
- libidn2=2.3.3=h166bdaf_0
|
131 |
+
- liblapack=3.9.0=15_linux64_mkl
|
132 |
+
- liblapacke=3.9.0=15_linux64_mkl
|
133 |
+
- libllvm11=11.1.0=hf817b99_3
|
134 |
+
- libnghttp2=1.47.0=h727a467_0
|
135 |
+
- libnsl=2.0.0=h7f98852_0
|
136 |
+
- libpciaccess=0.16=h516909a_0
|
137 |
+
- libpng=1.6.37=h753d276_3
|
138 |
+
- libprotobuf=3.20.1=h6239696_0
|
139 |
+
- libssh2=1.10.0=ha56f1ee_2
|
140 |
+
- libstdcxx-ng=12.1.0=ha89aaad_16
|
141 |
+
- libtasn1=4.18.0=h166bdaf_1
|
142 |
+
- libthrift=0.16.0=h519c5ea_1
|
143 |
+
- libtiff=4.4.0=hc85c160_1
|
144 |
+
- libunistring=0.9.10=h7f98852_0
|
145 |
+
- libutf8proc=2.7.0=h7f98852_0
|
146 |
+
- libuuid=2.32.1=h7f98852_1000
|
147 |
+
- libva=2.15.0=h166bdaf_0
|
148 |
+
- libvpx=1.11.0=h9c3ff4c_3
|
149 |
+
- libwebp=1.2.2=h3452ae3_0
|
150 |
+
- libwebp-base=1.2.2=h7f98852_1
|
151 |
+
- libxcb=1.13=h7f98852_1004
|
152 |
+
- libxml2=2.9.14=h22db469_3
|
153 |
+
- libzlib=1.2.12=h166bdaf_2
|
154 |
+
- llvm-openmp=14.0.4=he0ac6c6_0
|
155 |
+
- llvmlite=0.38.1=py310h58363a5_0
|
156 |
+
- lz4-c=1.9.3=h9c3ff4c_1
|
157 |
+
- magma=2.5.4=h6103c52_2
|
158 |
+
- mako=1.2.1=pyhd8ed1ab_0
|
159 |
+
- markdown=3.4.1=pyhd8ed1ab_0
|
160 |
+
- markupsafe=2.1.1=py310h5764c6d_1
|
161 |
+
- matplotlib-base=3.5.2=py310h5701ce4_0
|
162 |
+
- matplotlib-inline=0.1.3=pyhd8ed1ab_0
|
163 |
+
- missingno=0.4.2=py_1
|
164 |
+
- mkl=2022.1.0=h84fe81f_915
|
165 |
+
- mkl-devel=2022.1.0=ha770c72_916
|
166 |
+
- mkl-include=2022.1.0=h84fe81f_915
|
167 |
+
- mlflow=1.27.0=py310ha13cd29_0
|
168 |
+
- multidict=6.0.2=py310h5764c6d_1
|
169 |
+
- multimethod=1.4=py_0
|
170 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
171 |
+
- nccl=2.12.12.1=h0800d71_0
|
172 |
+
- ncurses=6.3=h27087fc_1
|
173 |
+
- nettle=3.8=hc379101_0
|
174 |
+
- networkx=2.8.4=pyhd8ed1ab_0
|
175 |
+
- ninja=1.11.0=h924138e_0
|
176 |
+
- numba=0.55.0=py310h00e6091_0
|
177 |
+
- numpy=1.23.1=py310h53a5b5f_0
|
178 |
+
- oauthlib=3.2.0=pyhd8ed1ab_0
|
179 |
+
- openh264=2.2.0=h6239696_1
|
180 |
+
- openjpeg=2.4.0=hb52868f_1
|
181 |
+
- openssl=1.1.1q=h166bdaf_0
|
182 |
+
- orc=1.7.5=h6c59b99_0
|
183 |
+
- p11-kit=0.24.1=hc5aa10d_0
|
184 |
+
- packaging=21.3=pyhd8ed1ab_0
|
185 |
+
- pandas=1.4.3=py310h769672d_0
|
186 |
+
- pandas-profiling=3.2.0=pyhd8ed1ab_0
|
187 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
188 |
+
- patsy=0.5.2=pyhd8ed1ab_0
|
189 |
+
- pexpect=4.8.0=pyh9f0ad1d_2
|
190 |
+
- phik=0.12.2=py310h7c64c84_0
|
191 |
+
- pickleshare=0.7.5=py_1003
|
192 |
+
- pillow=9.2.0=py310he619898_0
|
193 |
+
- pip=22.1.2=pyhd8ed1ab_0
|
194 |
+
- prometheus_client=0.14.1=pyhd8ed1ab_0
|
195 |
+
- prometheus_flask_exporter=0.20.2=pyhd8ed1ab_0
|
196 |
+
- prompt-toolkit=3.0.30=pyha770c72_0
|
197 |
+
- protobuf=3.20.1=py310hd8f1fbe_0
|
198 |
+
- pthread-stubs=0.4=h36c2ea0_1001
|
199 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
200 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
201 |
+
- pyarrow=8.0.0=py310h468efa6_0
|
202 |
+
- pyasn1=0.4.8=py_0
|
203 |
+
- pyasn1-modules=0.2.8=py_0
|
204 |
+
- pybind11-abi=4=hd8ed1ab_3
|
205 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
206 |
+
- pydantic=1.9.1=py310h5764c6d_0
|
207 |
+
- pydeprecate=0.3.2=pyhd8ed1ab_0
|
208 |
+
- pygments=2.12.0=pyhd8ed1ab_0
|
209 |
+
- pyjwt=2.4.0=pyhd8ed1ab_0
|
210 |
+
- pyopenssl=22.0.0=pyhd8ed1ab_0
|
211 |
+
- pyparsing=3.0.9=pyhd8ed1ab_0
|
212 |
+
- pysocks=1.7.1=py310hff52083_5
|
213 |
+
- python=3.10.5=h582c2e5_0_cpython
|
214 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
215 |
+
- python_abi=3.10=2_cp310
|
216 |
+
- pytorch=1.12.0=py3.10_cuda11.6_cudnn8.3.2_0
|
217 |
+
- pytorch-lightning=1.6.5=pyhd8ed1ab_0
|
218 |
+
- pytorch-mutex=1.0=cuda
|
219 |
+
- pytz=2022.1=pyhd8ed1ab_0
|
220 |
+
- pyu2f=0.1.5=pyhd8ed1ab_0
|
221 |
+
- pywavelets=1.3.0=py310hde88566_1
|
222 |
+
- pyyaml=6.0=py310h5764c6d_4
|
223 |
+
- querystring_parser=1.2.4=py_0
|
224 |
+
- re2=2022.06.01=h27087fc_0
|
225 |
+
- readline=8.1.2=h0f457ee_0
|
226 |
+
- requests=2.28.1=pyhd8ed1ab_0
|
227 |
+
- requests-oauthlib=1.3.1=pyhd8ed1ab_0
|
228 |
+
- rsa=4.8=pyhd8ed1ab_0
|
229 |
+
- s2n=1.0.10=h9b69904_0
|
230 |
+
- scikit-learn=1.1.1=py310hffb9edd_0
|
231 |
+
- scipy=1.8.1=py310h7612f91_0
|
232 |
+
- seaborn=0.11.2=hd8ed1ab_0
|
233 |
+
- seaborn-base=0.11.2=pyhd8ed1ab_0
|
234 |
+
- setuptools=59.5.0=py310hff52083_0
|
235 |
+
- shap=0.41.0=py310h769672d_0
|
236 |
+
- six=1.16.0=pyh6c4a22f_0
|
237 |
+
- sleef=3.5.1=h9b69904_2
|
238 |
+
- slicer=0.0.7=pyhd8ed1ab_0
|
239 |
+
- smmap=3.0.5=pyh44b312d_0
|
240 |
+
- snappy=1.1.9=hbd366e4_1
|
241 |
+
- sqlalchemy=1.4.39=py310h5764c6d_0
|
242 |
+
- sqlite=3.39.1=h4ff8645_0
|
243 |
+
- sqlparse=0.4.2=pyhd8ed1ab_0
|
244 |
+
- stack_data=0.3.0=pyhd8ed1ab_0
|
245 |
+
- statsmodels=0.13.2=py310hde88566_0
|
246 |
+
- svt-av1=1.1.0=h27087fc_1
|
247 |
+
- tabulate=0.8.10=pyhd8ed1ab_0
|
248 |
+
- tangled-up-in-unicode=0.2.0=pyhd8ed1ab_0
|
249 |
+
- tbb=2021.5.0=h924138e_1
|
250 |
+
- tensorboard=2.6.0=py_0
|
251 |
+
- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
|
252 |
+
- threadpoolctl=3.1.0=pyh8a188c0_0
|
253 |
+
- tk=8.6.12=h27826a3_0
|
254 |
+
- torchaudio=0.12.0=py310_cu116
|
255 |
+
- torchmetrics=0.9.2=pyhd8ed1ab_0
|
256 |
+
- torchvision=0.13.0=py310_cu116
|
257 |
+
- tqdm=4.64.0=pyhd8ed1ab_0
|
258 |
+
- traitlets=5.3.0=pyhd8ed1ab_0
|
259 |
+
- typing-extensions=4.3.0=hd8ed1ab_0
|
260 |
+
- typing_extensions=4.3.0=pyha770c72_0
|
261 |
+
- tzdata=2022a=h191b570_0
|
262 |
+
- unicodedata2=14.0.0=py310h5764c6d_1
|
263 |
+
- urllib3=1.26.10=pyhd8ed1ab_0
|
264 |
+
- visions=0.7.4=pyhd8ed1ab_0
|
265 |
+
- wcwidth=0.2.5=pyh9f0ad1d_2
|
266 |
+
- websocket-client=1.3.3=pyhd8ed1ab_0
|
267 |
+
- werkzeug=2.1.2=pyhd8ed1ab_1
|
268 |
+
- wheel=0.37.1=pyhd8ed1ab_0
|
269 |
+
- x264=1!161.3030=h7f98852_1
|
270 |
+
- x265=3.5=h924138e_3
|
271 |
+
- xorg-fixesproto=5.0=h7f98852_1002
|
272 |
+
- xorg-kbproto=1.0.7=h7f98852_1002
|
273 |
+
- xorg-libx11=1.7.2=h7f98852_0
|
274 |
+
- xorg-libxau=1.0.9=h7f98852_0
|
275 |
+
- xorg-libxdmcp=1.1.3=h7f98852_0
|
276 |
+
- xorg-libxext=1.3.4=h7f98852_1
|
277 |
+
- xorg-libxfixes=5.0.3=h7f98852_1004
|
278 |
+
- xorg-xextproto=7.3.0=h7f98852_1002
|
279 |
+
- xorg-xproto=7.0.31=h7f98852_1007
|
280 |
+
- xz=5.2.5=h516909a_1
|
281 |
+
- yaml=0.2.5=h7f98852_2
|
282 |
+
- yarl=1.7.2=py310h5764c6d_2
|
283 |
+
- zipp=3.8.0=pyhd8ed1ab_0
|
284 |
+
- zlib=1.2.12=h166bdaf_2
|
285 |
+
- zstd=1.5.2=h8a70e8d_2
|
286 |
+
prefix: /home/micha/.conda/envs/DL-US-inversion
|
loader.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines a Loader class to load data from a file or file wildcard
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import h5py
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import glob
|
11 |
+
from typing import Tuple
|
12 |
+
|
13 |
+
|
14 |
+
class Loader:
|
15 |
+
"""
|
16 |
+
Data loader class
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, **kwargs):
|
20 |
+
parser = Loader.add_argparse_args()
|
21 |
+
for action in parser._actions:
|
22 |
+
if action.dest in kwargs:
|
23 |
+
action.default = kwargs[action.dest]
|
24 |
+
args = parser.parse_args([])
|
25 |
+
self.__dict__.update(vars(args))
|
26 |
+
|
27 |
+
if type(self.label_vars) is str:
|
28 |
+
self.label_vars = [self.label_vars]
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def add_argparse_args(parent_parser=None):
|
32 |
+
"""
|
33 |
+
Add argeparse argument for the data loader
|
34 |
+
"""
|
35 |
+
parser = argparse.ArgumentParser(
|
36 |
+
prog='Loader',
|
37 |
+
usage=Loader.__doc__,
|
38 |
+
parents=[parent_parser] if parent_parser is not None else [],
|
39 |
+
add_help=False)
|
40 |
+
|
41 |
+
parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the label data')
|
42 |
+
parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')
|
43 |
+
|
44 |
+
parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
|
45 |
+
help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
|
46 |
+
parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]')
|
47 |
+
parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image')
|
48 |
+
|
49 |
+
parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
|
50 |
+
parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.')
|
51 |
+
|
52 |
+
return parser
|
53 |
+
|
54 |
+
def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
55 |
+
"""Loads training/testing data from file(s)
|
56 |
+
|
57 |
+
Arguments:
|
58 |
+
test_file_pattern {str} -- testing dataset(s) pattern
|
59 |
+
train_file_pattern {str} -- training dataset(s) pattern
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
(test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded
|
63 |
+
"""
|
64 |
+
|
65 |
+
test_inputs, test_labels = self._load_data_files(test_file_pattern)
|
66 |
+
train_inputs, train_labels = self._load_data_files(train_file_pattern)
|
67 |
+
|
68 |
+
if train_file_pattern is not None and train_inputs is None:
|
69 |
+
raise ValueError('Failed to load train set')
|
70 |
+
|
71 |
+
if test_file_pattern is not None and test_inputs is None:
|
72 |
+
raise ValueError('Failed to load train set')
|
73 |
+
|
74 |
+
return test_inputs, test_labels, train_inputs, train_labels
|
75 |
+
|
76 |
+
def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
77 |
+
""" Perform actual data loading
|
78 |
+
|
79 |
+
Args:
|
80 |
+
file_pattern: file name pattern
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
inputs and labels tensors
|
84 |
+
"""
|
85 |
+
|
86 |
+
inputs, labels = None, None
|
87 |
+
|
88 |
+
if file_pattern is None:
|
89 |
+
return inputs, labels
|
90 |
+
|
91 |
+
files = glob.glob(file_pattern)
|
92 |
+
|
93 |
+
if len(files) == 0:
|
94 |
+
raise ValueError(f'{file_pattern=} comes up empty')
|
95 |
+
|
96 |
+
# Load first file to get output dimensions
|
97 |
+
with h5py.File(files[0], 'r') as f:
|
98 |
+
if self.input_var not in f:
|
99 |
+
raise ValueError(f'input data key not in file: {self.input_var=}')
|
100 |
+
|
101 |
+
shape = list(f[self.input_var].shape)
|
102 |
+
if self.inputs_crop is not None:
|
103 |
+
for i in range(len(self.inputs_crop) // 2):
|
104 |
+
shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]
|
105 |
+
|
106 |
+
shape[0] *= len(files)
|
107 |
+
|
108 |
+
inputs = np.empty(shape, np.single)
|
109 |
+
|
110 |
+
if len(self.label_vars):
|
111 |
+
if not all([v in f for v in self.label_vars]):
|
112 |
+
raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')
|
113 |
+
|
114 |
+
shape = list(f[self.label_vars[0]].shape)
|
115 |
+
shape[1] *= len(self.label_vars)
|
116 |
+
if self.labels_crop is not None:
|
117 |
+
for i in range(len(self.labels_crop) // 2):
|
118 |
+
shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]
|
119 |
+
|
120 |
+
shape[-1] = int(shape[-1] * self.labels_resize)
|
121 |
+
shape[-2] = int(shape[-2] * self.labels_resize)
|
122 |
+
shape[0] *= len(files)
|
123 |
+
|
124 |
+
labels = np.empty(shape, np.single)
|
125 |
+
|
126 |
+
# Load data from files
|
127 |
+
pos = 0
|
128 |
+
for file in files:
|
129 |
+
with h5py.File(files[0], 'r') as f:
|
130 |
+
tmp_inputs = np.array(f[self.input_var])
|
131 |
+
|
132 |
+
if self.inputs_crop is not None:
|
133 |
+
slc = [slice(None)] * 4
|
134 |
+
for i in range(len(self.inputs_crop) // 2):
|
135 |
+
slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
|
136 |
+
tmp_inputs = tmp_inputs[tuple(slc)]
|
137 |
+
|
138 |
+
inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs
|
139 |
+
|
140 |
+
if len(self.label_vars):
|
141 |
+
tmp_labels = []
|
142 |
+
for v in self.label_vars:
|
143 |
+
tmp_labels.append(np.array(f[v]))
|
144 |
+
tmp_labels = np.concatenate(tmp_labels, axis=1)
|
145 |
+
|
146 |
+
if self.labels_crop is not None and self.labels_crop:
|
147 |
+
slc = [slice(None)] * 4
|
148 |
+
for i in range(len(self.labels_crop) // 2):
|
149 |
+
slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
|
150 |
+
tmp_labels = tmp_labels[tuple(slc)]
|
151 |
+
|
152 |
+
if self.labels_resize != 1.0:
|
153 |
+
tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()
|
154 |
+
|
155 |
+
labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels
|
156 |
+
|
157 |
+
pos += tmp_inputs.shape[0]
|
158 |
+
|
159 |
+
inputs = inputs[:pos, ...]
|
160 |
+
if len(self.label_vars):
|
161 |
+
labels = labels[:pos, ...]
|
162 |
+
|
163 |
+
if self.data_scale != 1.0:
|
164 |
+
inputs *= self.data_scale
|
165 |
+
|
166 |
+
if self.data_gain != 0.0:
|
167 |
+
gain = 10.0 ** np.linspace(0, self.data_gain, inputs.shape[-1], np.single).reshape((1, 1, 1, -1))
|
168 |
+
inputs *= gain
|
169 |
+
|
170 |
+
# Required when inputs is non-continuous due to transpose
|
171 |
+
# TODO: Could probably use a check on strides and do a conditional copy.
|
172 |
+
inputs = torch.from_numpy(inputs.copy())
|
173 |
+
if len(self.label_vars):
|
174 |
+
labels = torch.from_numpy(labels)
|
175 |
+
|
176 |
+
return inputs, labels
|
models/embc_sos.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f16952eb255caf2c845fd8d0cbdc17e5ea8d8a63cc7669699f8003cc388b22e6
|
3 |
+
size 17301623
|
models/tbme2_attn.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6888856204f0137b3c396fd554b1e5a30b97dcc066638f6297267804f0fd44c
|
3 |
+
size 9633441
|
models/tbme2_phase_sos.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5582ded80ec25c459a2307f698fccd0ea7d2c2537914fc65aa0ef0777ee27759
|
3 |
+
size 9639457
|
models/tbme2_sos.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9636c47b19f376e7236149584b715477249190f996f9f0f44a11692b02cf28c4
|
3 |
+
size 9633441
|
models/tbme_sos.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de4663eeb6b655d021fd69381703db2e0e90ba20ceb033586a6097f33d7883d4
|
3 |
+
size 6100085
|
net.py
ADDED
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Network definition file
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchaudio.functional import lfilter
|
9 |
+
|
10 |
+
from pytorch_lightning import LightningModule
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from scipy.signal import butter, gaussian
|
14 |
+
from copy import deepcopy
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
|
18 |
+
class Net(LightningModule):
|
19 |
+
def __init__(self, **kwargs):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
parser = Net.add_model_specific_args()
|
23 |
+
for action in parser._actions:
|
24 |
+
if action.dest in kwargs:
|
25 |
+
action.default = kwargs[action.dest]
|
26 |
+
|
27 |
+
args = parser.parse_args([])
|
28 |
+
self.hparams.update(vars(args))
|
29 |
+
|
30 |
+
if not hasattr(self, f"_init_{self.hparams.net_type}_net"):
|
31 |
+
raise ValueError(f"Unknown net type {self.hparams.net_type}")
|
32 |
+
|
33 |
+
self._net = eval(f"self._init_{self.hparams.net_type}_net(n_inputs={self.hparams.n_inputs}, n_outputs={self.hparams.n_outputs})")
|
34 |
+
|
35 |
+
if self.hparams.bias is not None:
|
36 |
+
if hasattr(self.hparams.bias, "__iter__"):
|
37 |
+
for i in range(len(self.hparams.bias)):
|
38 |
+
self._net[-1].c.bias[i].data.fill_(self.hparams.bias[i])
|
39 |
+
else:
|
40 |
+
self._net[-1].c.bias.data.fill_(self.hparams.bias)
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def _init_tbme2_net(n_inputs: int = 1, n_outputs: int = 1):
|
44 |
+
return nn.Sequential(
|
45 |
+
# Encoder
|
46 |
+
DownBlock(n_inputs, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
|
47 |
+
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
|
48 |
+
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
|
49 |
+
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=True, layers=3),
|
50 |
+
DownBlock(32, 32, 64, 3, stride=1, pool=[2, 2], push=True, layers=3),
|
51 |
+
DownBlock(64, 64, 128, 3, stride=1, pool=[2, 2], push=True, layers=3),
|
52 |
+
DownBlock(128, 128, 512, 3, stride=1, pool=[2, 2], push=False, layers=3),
|
53 |
+
# Decoder
|
54 |
+
UpBlock(512, 128, 3, scale_factor=2, pop=False, layers=3),
|
55 |
+
UpBlock(256, 64, 3, scale_factor=2, pop=True, layers=3),
|
56 |
+
UpBlock(128, 32, 3, scale_factor=2, pop=True, layers=3),
|
57 |
+
UpBlock(64, 32, 3, scale_factor=2, pop=True, layers=3),
|
58 |
+
UpStep(32, 32, 3, scale_factor=1),
|
59 |
+
Compress(32, n_outputs))
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def _init_embc_net(n_inputs: int = 1, n_outputs: int = 1):
|
63 |
+
return nn.Sequential(
|
64 |
+
# Encoder
|
65 |
+
DownBlock(n_inputs, 32, 32, 15, [1, 2], None, layers=1),
|
66 |
+
DownBlock(32, 32, 32, 13, [1, 2], None, layers=1),
|
67 |
+
DownBlock(32, 32, 32, 11, [1, 2], None, layers=1),
|
68 |
+
DownBlock(32, 32, 32, 9, [1, 2], None, True, layers=1),
|
69 |
+
DownBlock(32, 32, 64, 7, 1, [2, 2], True, layers=1),
|
70 |
+
DownBlock(64, 64, 128, 5, 1, [2, 2], True, layers=1),
|
71 |
+
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
|
72 |
+
# Decoder
|
73 |
+
UpBlock(512, 128, 5, 2, layers=1),
|
74 |
+
UpBlock(256, 64, 7, 2, True, layers=1),
|
75 |
+
UpBlock(128, 32, 9, 2, True, layers=1),
|
76 |
+
UpBlock(64, 32, 11, 2, True, layers=1),
|
77 |
+
UpStep(32, 32, 3, 1),
|
78 |
+
Compress(32, n_outputs))
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def _init_tbme_net(n_inputs: int = 1, n_outputs: int = 1):
|
82 |
+
return nn.Sequential(
|
83 |
+
# Encoder
|
84 |
+
DownBlock(n_inputs, 32, 32, 3, [1, 2], None, layers=1),
|
85 |
+
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
|
86 |
+
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
|
87 |
+
DownBlock(32, 32, 32, 3, [1, 2], None, True, layers=1),
|
88 |
+
DownBlock(32, 32, 64, 3, 1, [2, 2], True, layers=1),
|
89 |
+
DownBlock(64, 64, 128, 3, 1, [2, 2], True, layers=1),
|
90 |
+
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
|
91 |
+
# Decoder
|
92 |
+
UpBlock(512, 128, 3, 2, layers=1),
|
93 |
+
UpBlock(256, 64, 3, 2, True, layers=1),
|
94 |
+
UpBlock(128, 32, 3, 2, True, layers=1),
|
95 |
+
UpBlock(64, 32, 3, 2, True, layers=1),
|
96 |
+
UpStep(32, 32, 3, 1),
|
97 |
+
Compress(32, n_outputs))
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def add_model_specific_args(parent_parser=None):
|
101 |
+
parser = argparse.ArgumentParser(
|
102 |
+
prog="Net",
|
103 |
+
usage=Net.__doc__,
|
104 |
+
parents=[parent_parser] if parent_parser is not None else [],
|
105 |
+
add_help=False)
|
106 |
+
|
107 |
+
parser.add_argument("--random_mirror", type=int, nargs="?", default=1, help="Randomly mirror data to increase diversity when using flat plate wave")
|
108 |
+
parser.add_argument("--noise_std", type=float, nargs="*", help="range of std of random noise to add to the input signal [0 val] or [min max]")
|
109 |
+
parser.add_argument("--quantization", type=float, nargs="?", help="Quantization noise")
|
110 |
+
parser.add_argument("--rand_drop", type=int, nargs="*", help="Random drop lines, between 0 and value lines if single value, or between two values")
|
111 |
+
parser.add_argument("--normalize_net", type=float, default=0.0, help="Coefficient for normalizing network weights")
|
112 |
+
|
113 |
+
parser.add_argument("--learning_rate", type=float, default=5e-3, help="Learning rate to use for optimizer")
|
114 |
+
parser.add_argument("--lr_sched_step", type=int, default=15, help="Learning decay, update step size")
|
115 |
+
parser.add_argument("--lr_sched_gamma", type=float, default=0.65, help="Learning decay gamma")
|
116 |
+
|
117 |
+
parser.add_argument("--net_type", default="tbme2", help="The network to use [tbme2/embc/tbme]")
|
118 |
+
parser.add_argument("--bias", type=float, nargs="*", help="Set bias on last layer, set to 1500 when training from scratch on SoS output")
|
119 |
+
parser.add_argument("--decimation", type=int, help="Subsample phase signal")
|
120 |
+
parser.add_argument("--phase_inv", type=int, default=0, help="Use phase for inversion")
|
121 |
+
|
122 |
+
parser.add_argument("--center_freq", type=float, default=5e6, help="Matched filter and IQ demodulation frequency")
|
123 |
+
parser.add_argument("--n_periods", type=float, default=5, help="Matched filter length")
|
124 |
+
parser.add_argument("--matched_filter", type=int, nargs="?", default=0, help="Apply matched filter, set to 1 to run during forward pass, 2 to run during preprocessing phase (before adding noise)")
|
125 |
+
|
126 |
+
parser.add_argument("--rand_output_crop", type=int, help="Subsample phase signal")
|
127 |
+
parser.add_argument("--rand_scale", type=float, nargs="*", help="Random scaling range [min max] -- (10 ** rand_scale)")
|
128 |
+
parser.add_argument("--rand_gain", type=float, nargs="*", help="Random gain coefficient range [min max] -- (10 ** rand_gain)")
|
129 |
+
|
130 |
+
parser.add_argument("--n_inputs", type=int, default=1, help="Number of input layers")
|
131 |
+
parser.add_argument("--n_outputs", type=int, default=1, help="Number of output layers")
|
132 |
+
parser.add_argument("--scale_losses", type=float, nargs="*", help="Scale each layer of the loss function by given value")
|
133 |
+
|
134 |
+
return parser
|
135 |
+
|
136 |
+
def forward(self, x) -> torch.Tensor:
|
137 |
+
# Matched filter
|
138 |
+
if self.hparams.matched_filter == 1:
|
139 |
+
x = self._matched_filter(x)
|
140 |
+
|
141 |
+
# compute IQ phase if in phase_inv mode
|
142 |
+
if self.hparams.phase_inv:
|
143 |
+
x = self._phase(x)
|
144 |
+
|
145 |
+
# Decimation
|
146 |
+
if self.hparams.decimation != 1:
|
147 |
+
x = x[..., ::self.hparams.decimation]
|
148 |
+
|
149 |
+
# Apply network
|
150 |
+
x = self._net((x, []))
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
def _matched_filter(self, x):
|
155 |
+
sampling_freq = 40e6
|
156 |
+
|
157 |
+
samples_per_cycle = sampling_freq / self.hparams.center_freq
|
158 |
+
n_samples = np.ceil(samples_per_cycle * self.hparams.n_periods + 1)
|
159 |
+
|
160 |
+
signal = torch.sin(torch.arange(n_samples, device=x.device) / samples_per_cycle * 2 * np.pi) * torch.from_numpy(gaussian(n_samples, (n_samples - 1) / 6).astype(np.single)).to(x.device)
|
161 |
+
|
162 |
+
return torch.nn.functional.conv1d(x.reshape(x.shape[:2] + (-1,)), signal.reshape(1, 1, -1), padding="same").reshape(x.shape)
|
163 |
+
|
164 |
+
def _phase(self, x):
|
165 |
+
f = self.hparams.center_freq
|
166 |
+
F = 40e6
|
167 |
+
N = x.shape[-1]
|
168 |
+
|
169 |
+
n = int(round(f * N / F))
|
170 |
+
|
171 |
+
X = torch.fft.fft(x, dim=-1)
|
172 |
+
X[..., (2 * n + 1):] = 0
|
173 |
+
X[..., :(2 * n + 1)] *= torch.from_numpy(gaussian(2 * n + 1, 2 * n / 6).astype(np.single)).to(x.device)
|
174 |
+
X = X.roll(-n, dims=-1)
|
175 |
+
x = torch.fft.ifft(X, dim=-1)
|
176 |
+
|
177 |
+
return x.angle()
|
178 |
+
|
179 |
+
def _preprocess(self, x):
|
180 |
+
# Matched filter
|
181 |
+
if self.hparams.matched_filter == 2:
|
182 |
+
x = self._matched_filter(x)
|
183 |
+
|
184 |
+
# Gaussian (normal) noise - random scaling, normalized to signal STD
|
185 |
+
if (ns := self.hparams.noise_std) and len(ns):
|
186 |
+
scl = ns[0] if len(ns) == 1 else torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (ns[-1] - ns[-2]) + ns[-2]
|
187 |
+
scl *= x.std()
|
188 |
+
x += torch.empty_like(x).normal_() * scl
|
189 |
+
|
190 |
+
# Random multiplicative scaling
|
191 |
+
if (rs := self.hparams.rand_scale) and len(rs):
|
192 |
+
x *= 10 ** (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (rs[-1] - rs[-2]) + rs[-2])
|
193 |
+
|
194 |
+
# Random exponential gain
|
195 |
+
if (gs := self.hparams.rand_gain) and len(gs):
|
196 |
+
gain = torch.FloatTensor([10.0]).to(x.device) ** \
|
197 |
+
(torch.rand([x.shape[0]] + [1] * 3).to(x.device) * ((gs[-1] - gs[-2]) + gs[-2]) *
|
198 |
+
torch.linspace(0, 1, x.shape[-1]).to(x.device).view(1, 1, 1, -1))
|
199 |
+
x *= gain
|
200 |
+
|
201 |
+
# Quantization noise, to emulated ADC
|
202 |
+
if (quantization := self.hparams.quantization) is not None:
|
203 |
+
x = (x * quantization).round() * (1.0 / quantization)
|
204 |
+
|
205 |
+
# Randomly zero out some of the channels
|
206 |
+
if (rand_drop := self.hparams.rand_drop) and len(rand_drop):
|
207 |
+
if len(rand_drop) == 1:
|
208 |
+
rand_drop = [0, ] + rand_drop
|
209 |
+
|
210 |
+
for i in range(x.shape[0]):
|
211 |
+
lines = np.random.randint(0, x.shape[2], np.random.randint(rand_drop[0], rand_drop[1] + 1))
|
212 |
+
x[i, :, lines, :] = 0.
|
213 |
+
|
214 |
+
return x
|
215 |
+
|
216 |
+
def _log_losses(self, outputs: torch.Tensor, labels: torch.Tensor, prefix: str = ""):
|
217 |
+
diff = torch.abs(labels.detach() - outputs.detach())
|
218 |
+
|
219 |
+
s1 = int(diff.shape[-1] * (1.0 / 3.0))
|
220 |
+
s2 = int(diff.shape[-1] * (2.0 / 3.0))
|
221 |
+
|
222 |
+
for i in range(diff.shape[1]):
|
223 |
+
tag = f"{i}_" if diff.shape[1] > 1 else ""
|
224 |
+
|
225 |
+
losses = {
|
226 |
+
f"{prefix + tag}rmse": torch.sqrt(torch.mean(diff[:, i, ...] * diff[:, i, ...])).item(),
|
227 |
+
f"{prefix + tag}mean": torch.mean(diff[:, i, ...]).item(),
|
228 |
+
f"{prefix + tag}short": torch.mean(diff[:, i, :, :s1]).item(),
|
229 |
+
f"{prefix + tag}med": torch.mean(diff[:, i, :, s1:s2]).item(),
|
230 |
+
f"{prefix + tag}long": torch.mean(diff[:, i, :, s2:]).item()}
|
231 |
+
|
232 |
+
self.log_dict(losses, prog_bar=True)
|
233 |
+
|
234 |
+
def training_step(self, batch, batch_idx):
|
235 |
+
if self.hparams.random_mirror:
|
236 |
+
mirror = np.random.randint(0, 2, batch[0].shape[0])
|
237 |
+
|
238 |
+
for b in batch:
|
239 |
+
for i, m in enumerate(mirror):
|
240 |
+
if not m:
|
241 |
+
continue
|
242 |
+
|
243 |
+
b[i, ...] = b[i, :, range(b.shape[-2] - 1, -1, -1), :] # Pytorch does not handle negative steps
|
244 |
+
|
245 |
+
loss = self._common_step(batch, batch_idx, "train_")
|
246 |
+
|
247 |
+
if self.hparams.normalize_net:
|
248 |
+
for W in self.parameters():
|
249 |
+
loss += self.hparams.normalize_net * W.norm(2)
|
250 |
+
|
251 |
+
return loss
|
252 |
+
|
253 |
+
def validation_step(self, batch, batch_idx):
|
254 |
+
return self._common_step(batch, batch_idx, "validate_")
|
255 |
+
|
256 |
+
def test_step(self, batch, batch_idx):
|
257 |
+
return self._common_step(batch, batch_idx, "test_")
|
258 |
+
|
259 |
+
def predict_step(self, batch, batch_idx):
|
260 |
+
x = batch[0]
|
261 |
+
|
262 |
+
x = self._preprocess(x)
|
263 |
+
z = self(x)
|
264 |
+
|
265 |
+
if isinstance(z, tuple):
|
266 |
+
z = z[0]
|
267 |
+
|
268 |
+
return z
|
269 |
+
|
270 |
+
def _common_step(self, batch, batch_idx, prefix):
|
271 |
+
x, y = batch
|
272 |
+
|
273 |
+
if self.hparams.rand_output_crop:
|
274 |
+
crop = np.random.randint(0, self.hparams.rand_output_crop, batch[0].shape[0])
|
275 |
+
|
276 |
+
for i, c in enumerate(crop):
|
277 |
+
if not c:
|
278 |
+
continue
|
279 |
+
|
280 |
+
x[i, :, :-c, :] = x[i, :, c:, :].clone()
|
281 |
+
y[i, :, :-c*2, :] = \
|
282 |
+
y[i, :, c*2-1:-1, :].clone() if np.random.randint(2) else \
|
283 |
+
y[i, :, c*2:, :].clone()
|
284 |
+
|
285 |
+
x = x[..., :-self.hparams.rand_output_crop, :]
|
286 |
+
y = y[..., :-self.hparams.rand_output_crop*2, :]
|
287 |
+
|
288 |
+
x = self._preprocess(x)
|
289 |
+
z = self(x)
|
290 |
+
|
291 |
+
outputs = z[0] if isinstance(z, tuple) or isinstance(z, list) else z
|
292 |
+
self._log_losses(outputs, y, prefix)
|
293 |
+
|
294 |
+
if (self.hparams.scale_losses) and len(self.hparams.scale_losses):
|
295 |
+
s = torch.FloatTensor(self.hparams.scale_losses).to(y.device).view(1, -1, 1, 1)
|
296 |
+
loss = F.mse_loss(s * z, s * y)
|
297 |
+
else:
|
298 |
+
loss = F.mse_loss(y, outputs)
|
299 |
+
|
300 |
+
self.log(prefix + "loss", np.sqrt(loss.item()))
|
301 |
+
|
302 |
+
return loss
|
303 |
+
|
304 |
+
def configure_optimizers(self):
|
305 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
306 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.lr_sched_step, self.hparams.lr_sched_gamma)
|
307 |
+
|
308 |
+
return [optimizer], [scheduler]
|
309 |
+
|
310 |
+
|
311 |
+
class DownStep(nn.Module):
|
312 |
+
"""
|
313 |
+
Down scaling step in the encoder decoder network
|
314 |
+
"""
|
315 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: int = 1, pool: tuple = None) -> None:
|
316 |
+
"""Constructor
|
317 |
+
|
318 |
+
Arguments:
|
319 |
+
in_channels {int} -- Number of input channels for 2D convolution
|
320 |
+
out_channels {int} -- Number of output channels for 2D convolution
|
321 |
+
kernel_size {tuple} -- Convolution kernel size
|
322 |
+
|
323 |
+
Keyword Arguments:
|
324 |
+
stride {int} -- Stride of convolution, set to 1 to disable (default: {1})
|
325 |
+
pool {tuple} -- max pulling size, set to None to disable (default: {None})
|
326 |
+
"""
|
327 |
+
super(DownStep, self).__init__()
|
328 |
+
|
329 |
+
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
|
330 |
+
self.n = nn.BatchNorm2d(out_channels)
|
331 |
+
self.pool = pool
|
332 |
+
|
333 |
+
def forward(self, x: torch.tensor) -> torch.tensor:
|
334 |
+
"""Run the forward step
|
335 |
+
|
336 |
+
Arguments:
|
337 |
+
x {torch.tensor} -- input tensor
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
torch.tensor -- output tensor
|
341 |
+
"""
|
342 |
+
x = self.c(x)
|
343 |
+
x = F.relu(x)
|
344 |
+
if self.pool is not None:
|
345 |
+
x = F.max_pool2d(x, self.pool)
|
346 |
+
x = self.n(x)
|
347 |
+
|
348 |
+
return x
|
349 |
+
|
350 |
+
|
351 |
+
class UpStep(nn.Module):
|
352 |
+
"""
|
353 |
+
Up scaling step in the encoder decoder network
|
354 |
+
"""
|
355 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, scale_factor: int = 2) -> None:
|
356 |
+
"""Constructor
|
357 |
+
|
358 |
+
Arguments:
|
359 |
+
in_channels {int} -- Number of input channels for 2D convolution
|
360 |
+
out_channels {int} -- Number of output channels for 2D convolution
|
361 |
+
kernel_size {int} -- Convolution kernel size
|
362 |
+
|
363 |
+
Keyword Arguments:
|
364 |
+
scale_factor {int} -- Upsampling scaling factor (default: {2})
|
365 |
+
"""
|
366 |
+
super(UpStep, self).__init__()
|
367 |
+
|
368 |
+
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
369 |
+
self.n = nn.BatchNorm2d(out_channels)
|
370 |
+
self.scale_factor = scale_factor
|
371 |
+
|
372 |
+
def forward(self, x: torch.tensor) -> torch.tensor:
|
373 |
+
"""Run the forward step
|
374 |
+
|
375 |
+
Arguments:
|
376 |
+
x {torch.tensor} -- input tensor
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
torch.tensor -- output tensor
|
380 |
+
"""
|
381 |
+
if isinstance(x, tuple):
|
382 |
+
x = x[0]
|
383 |
+
|
384 |
+
if self.scale_factor != 1:
|
385 |
+
x = F.interpolate(x, scale_factor=self.scale_factor)
|
386 |
+
|
387 |
+
x = self.c(x)
|
388 |
+
x = F.relu(x)
|
389 |
+
x = self.n(x)
|
390 |
+
|
391 |
+
return x
|
392 |
+
|
393 |
+
|
394 |
+
class Compress(nn.Module):
|
395 |
+
"""
|
396 |
+
Up scaling step in the encoder decoder network
|
397 |
+
"""
|
398 |
+
def __init__(self, in_channels: int, out_channels: int = 1, kernel_size: int = 1, scale_factor: int = 1) -> None:
|
399 |
+
"""Constructor
|
400 |
+
|
401 |
+
Arguments:
|
402 |
+
in_channels {int} -- [description]
|
403 |
+
|
404 |
+
Keyword Arguments:
|
405 |
+
out_channels {int} -- [description] (default: {1})
|
406 |
+
kernel_size {int} -- [description] (default: {1})
|
407 |
+
"""
|
408 |
+
super(Compress, self).__init__()
|
409 |
+
|
410 |
+
self.scale_factor = scale_factor
|
411 |
+
|
412 |
+
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
413 |
+
|
414 |
+
def forward(self, x: torch.tensor) -> torch.tensor:
|
415 |
+
"""Run the forward step
|
416 |
+
|
417 |
+
Arguments:
|
418 |
+
x {torch.tensor} -- input tensor
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
torch.tensor -- output tensor
|
422 |
+
"""
|
423 |
+
if isinstance(x, tuple) or isinstance(x, list):
|
424 |
+
x = x[0]
|
425 |
+
|
426 |
+
x = self.c(x)
|
427 |
+
|
428 |
+
if self.scale_factor != 1:
|
429 |
+
x = F.interpolate(x, scale_factor=self.scale_factor)
|
430 |
+
|
431 |
+
return x
|
432 |
+
|
433 |
+
|
434 |
+
class DownBlock(nn.Module):
|
435 |
+
def __init__(
|
436 |
+
self,
|
437 |
+
in_chan: int, inter_chan: int, out_chan: int,
|
438 |
+
kernel_size: int = 3, stride: int = 1, pool: tuple = None,
|
439 |
+
push: bool = False,
|
440 |
+
layers: int = 3):
|
441 |
+
super().__init__()
|
442 |
+
|
443 |
+
self.s = []
|
444 |
+
for i in range(layers):
|
445 |
+
self.s.append(deepcopy(DownStep(
|
446 |
+
in_chan if i == 0 else inter_chan,
|
447 |
+
inter_chan if i < layers - 1 else out_chan,
|
448 |
+
kernel_size,
|
449 |
+
1 if i < layers - 1 else stride,
|
450 |
+
None if i < layers - 1 else pool)))
|
451 |
+
self.s = nn.Sequential(*self.s)
|
452 |
+
|
453 |
+
self.push = push
|
454 |
+
|
455 |
+
def forward(self, x: torch.tensor) -> torch.tensor:
|
456 |
+
i, s = x
|
457 |
+
|
458 |
+
i = self.s(i)
|
459 |
+
|
460 |
+
if self.push:
|
461 |
+
s.append(i)
|
462 |
+
|
463 |
+
return i, s
|
464 |
+
|
465 |
+
|
466 |
+
class UpBlock(nn.Module):
|
467 |
+
def __init__(
|
468 |
+
self,
|
469 |
+
in_chan: int, out_chan: int,
|
470 |
+
kernel_size: int, scale_factor: int = 2,
|
471 |
+
pop: bool = False,
|
472 |
+
layers: int = 3):
|
473 |
+
super().__init__()
|
474 |
+
|
475 |
+
self.s = []
|
476 |
+
for i in range(layers):
|
477 |
+
self.s.append(deepcopy(UpStep(
|
478 |
+
in_chan if i == 0 else out_chan,
|
479 |
+
out_chan,
|
480 |
+
kernel_size,
|
481 |
+
1 if i < layers - 1 else scale_factor)))
|
482 |
+
self.s = nn.Sequential(*self.s)
|
483 |
+
|
484 |
+
self.pop = pop
|
485 |
+
|
486 |
+
def forward(self, x: torch.tensor) -> torch.tensor:
|
487 |
+
i, s = x
|
488 |
+
|
489 |
+
if self.pop:
|
490 |
+
i = torch.cat((i, s.pop()), dim=1)
|
491 |
+
|
492 |
+
i = self.s(i)
|
493 |
+
|
494 |
+
return i, s
|
run_logger.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Support log functions
|
3 |
+
|
4 |
+
TODO: log model using mlflow.pytorch in parallel / addition to checkpointing
|
5 |
+
"""
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import h5py
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torchvision.utils as vutils
|
14 |
+
import pytorch_lightning as pl
|
15 |
+
|
16 |
+
|
17 |
+
class ImgCB(pl.Callback):
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
parser = ImgCB.add_argparse_args()
|
20 |
+
for action in parser._actions:
|
21 |
+
if action.dest in kwargs:
|
22 |
+
action.default = kwargs[action.dest]
|
23 |
+
args = parser.parse_args([])
|
24 |
+
self.__dict__.update(vars(args))
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def add_argparse_args(parent_parser=None):
|
28 |
+
parser = argparse.ArgumentParser(
|
29 |
+
prog='ImgCB',
|
30 |
+
usage=ImgCB.__doc__,
|
31 |
+
parents=[parent_parser] if parent_parser is not None else [],
|
32 |
+
add_help=False)
|
33 |
+
|
34 |
+
parser.add_argument('--img_ranges', default=[1300, 1800], nargs='*', help='Scaling range on output image, either pair, or set of pairs')
|
35 |
+
parser.add_argument('--err_ranges', default=[0, 50], nargs='*', help='Scaling range on error images, either pair, or set of pairs')
|
36 |
+
|
37 |
+
return parser
|
38 |
+
|
39 |
+
def log_images(self, mfl_logger, y, z, prefix):
|
40 |
+
img_ranges = tuple(self.img_ranges)
|
41 |
+
err_ranges = tuple(self.err_ranges)
|
42 |
+
#
|
43 |
+
for i in range(y.shape[1]):
|
44 |
+
if y.shape[1] > 1:
|
45 |
+
tag = f'_{i}_'
|
46 |
+
|
47 |
+
if len(self.img_ranges) > 2:
|
48 |
+
img_ranges = tuple(self.img_ranges[2*i, 2*i + 1])
|
49 |
+
if len(self.err_ranges) > 2:
|
50 |
+
err_ranges = tuple(self.err_ranges[2*i, 2*i + 1])
|
51 |
+
else:
|
52 |
+
tag = ''
|
53 |
+
|
54 |
+
mfl_logger.experiment.log_image(
|
55 |
+
mfl_logger.run_id,
|
56 |
+
(np.array(vutils.make_grid(
|
57 |
+
y[:, [i], ...].detach(),
|
58 |
+
normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int),
|
59 |
+
prefix + tag + '_labels.png')
|
60 |
+
|
61 |
+
mfl_logger.experiment.log_image(
|
62 |
+
mfl_logger.run_id,
|
63 |
+
(np.array(vutils.make_grid(
|
64 |
+
z[:, [i], ...].detach(),
|
65 |
+
normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int),
|
66 |
+
prefix + tag + '_outputs.png')
|
67 |
+
|
68 |
+
mfl_logger.experiment.log_image(
|
69 |
+
mfl_logger.run_id,
|
70 |
+
(np.array(vutils.make_grid(
|
71 |
+
torch.abs(y[:, [i], ...].detach() - z[:, [i], ...].detach()),
|
72 |
+
normalize=True, value_range=err_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int),
|
73 |
+
prefix + tag + '_errors.png')
|
74 |
+
|
75 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
76 |
+
if batch_idx == 0:
|
77 |
+
with torch.no_grad():
|
78 |
+
x, y = batch
|
79 |
+
|
80 |
+
if pl_module.hparams.rand_output_crop:
|
81 |
+
x = x[..., :-pl_module.hparams.rand_output_crop, :]
|
82 |
+
y = y[..., :-pl_module.hparams.rand_output_crop * 2, :]
|
83 |
+
|
84 |
+
z = pl_module(x.to(pl_module.device))
|
85 |
+
|
86 |
+
if isinstance(z, tuple) or isinstance(z, list):
|
87 |
+
z = z[0]
|
88 |
+
|
89 |
+
self.log_images(pl_module.logger, y.to(pl_module.device), z, 'train_')
|
90 |
+
|
91 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
92 |
+
if batch_idx == 0:
|
93 |
+
with torch.no_grad():
|
94 |
+
x, y = batch
|
95 |
+
|
96 |
+
if pl_module.hparams.rand_output_crop:
|
97 |
+
x = x[..., :-pl_module.hparams.rand_output_crop, :]
|
98 |
+
y = y[..., :-pl_module.hparams.rand_output_crop * 2, :]
|
99 |
+
|
100 |
+
z = pl_module(x.to(pl_module.device))
|
101 |
+
|
102 |
+
if isinstance(z, tuple) or isinstance(z, list):
|
103 |
+
z = z[0]
|
104 |
+
|
105 |
+
self.log_images(pl_module.logger, y.to(pl_module.device), z, 'validate_')
|
106 |
+
|
107 |
+
|
108 |
+
class TestLogger(pl.Callback):
|
109 |
+
"""
|
110 |
+
pytorch_lightning Data saving logger for testing output
|
111 |
+
Warning !!! : this function is not multi GPU / multi device safe -- only run on a single gpu / device
|
112 |
+
"""
|
113 |
+
def __init__(self, fname: str = 'output.h5'):
|
114 |
+
self.fname = fname
|
115 |
+
|
116 |
+
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
117 |
+
with h5py.File(self.fname, 'a') as f:
|
118 |
+
f[f'batch_{batch_idx:05}'] = outputs.to('cpu').numpy()
|
119 |
+
if len(batch) > 1:
|
120 |
+
f[f'labels_{batch_idx:05}'] = batch[1].to('cpu').numpy()
|