Spaces:
Runtime error
Runtime error
Refactored source
Browse files
- app.py +1 -2
- src/training/dcc_tf.py → dcc_tf.py +0 -0
- src/__init__.py +0 -0
- src/helpers/__init__.py +0 -0
- src/helpers/utils.py +0 -205
- src/training/__init__.py +0 -0
- src/training/eval.py +0 -214
- src/training/synthetic_dataset.py +0 -168
- src/training/train.py +0 -311
app.py
CHANGED
@@ -6,7 +6,6 @@ import torch
|
|
6 |
import torchaudio
|
7 |
import gradio as gr
|
8 |
|
9 |
-
from src.helpers import utils
|
10 |
from src.training.dcc_tf import Net as Waveformer
|
11 |
|
12 |
TARGETS = [
|
@@ -34,7 +33,7 @@ if not os.path.exists('default_ckpt.pt'):
|
|
34 |
# Instantiate model
|
35 |
params = utils.Params('default_config.json')
|
36 |
model = Waveformer(**params.model_params)
|
37 |
-
|
38 |
model.eval()
|
39 |
|
40 |
def waveformer(audio, label_choices):
|
|
|
6 |
import torchaudio
|
7 |
import gradio as gr
|
8 |
|
|
|
9 |
from src.training.dcc_tf import Net as Waveformer
|
10 |
|
11 |
TARGETS = [
|
|
|
33 |
# Instantiate model
|
34 |
params = utils.Params('default_config.json')
|
35 |
model = Waveformer(**params.model_params)
|
36 |
+
model.load_state_dict(torch.load('default_ckpt.pt', map_location=torch.device('cpu')))
|
37 |
model.eval()
|
38 |
|
39 |
def waveformer(audio, label_choices):
|
src/training/dcc_tf.py → dcc_tf.py
RENAMED
File without changes
|
src/__init__.py
DELETED
File without changes
|
src/helpers/__init__.py
DELETED
File without changes
|
src/helpers/utils.py
DELETED
@@ -1,205 +0,0 @@
|
|
1 |
-
"""A collection of useful helper functions"""
|
2 |
-
|
3 |
-
import os
|
4 |
-
import logging
|
5 |
-
import json
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from torch.profiler import profile, record_function, ProfilerActivity
|
9 |
-
import pandas as pd
|
10 |
-
from torchmetrics.functional import(
|
11 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
12 |
-
signal_noise_ratio as snr,
|
13 |
-
signal_distortion_ratio as sdr,
|
14 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
15 |
-
import matplotlib.pyplot as plt
|
16 |
-
|
17 |
-
class Params():
|
18 |
-
"""Class that loads hyperparameters from a json file.
|
19 |
-
Example:
|
20 |
-
```
|
21 |
-
params = Params(json_path)
|
22 |
-
print(params.learning_rate)
|
23 |
-
params.learning_rate = 0.5 # change the value of learning_rate in params
|
24 |
-
```
|
25 |
-
"""
|
26 |
-
|
27 |
-
def __init__(self, json_path):
|
28 |
-
with open(json_path) as f:
|
29 |
-
params = json.load(f)
|
30 |
-
self.__dict__.update(params)
|
31 |
-
|
32 |
-
def save(self, json_path):
|
33 |
-
with open(json_path, 'w') as f:
|
34 |
-
json.dump(self.__dict__, f, indent=4)
|
35 |
-
|
36 |
-
def update(self, json_path):
|
37 |
-
"""Loads parameters from json file"""
|
38 |
-
with open(json_path) as f:
|
39 |
-
params = json.load(f)
|
40 |
-
self.__dict__.update(params)
|
41 |
-
|
42 |
-
@property
|
43 |
-
def dict(self):
|
44 |
-
"""Gives dict-like access to Params instance by `params.dict['learning_rate']"""
|
45 |
-
return self.__dict__
|
46 |
-
|
47 |
-
def save_graph(train_metrics, test_metrics, save_dir):
|
48 |
-
metrics = [snr, si_snr]
|
49 |
-
results = {'train_loss': train_metrics['loss'],
|
50 |
-
'test_loss' : test_metrics['loss']}
|
51 |
-
|
52 |
-
for m_fn in metrics:
|
53 |
-
results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__]
|
54 |
-
results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__]
|
55 |
-
|
56 |
-
results_pd = pd.DataFrame(results)
|
57 |
-
|
58 |
-
results_pd.to_csv(os.path.join(save_dir, 'results.csv'))
|
59 |
-
|
60 |
-
fig, temp_ax = plt.subplots(2, 3, figsize=(15,10))
|
61 |
-
axs=[]
|
62 |
-
for i in temp_ax:
|
63 |
-
for j in i:
|
64 |
-
axs.append(j)
|
65 |
-
|
66 |
-
x = range(len(train_metrics['loss']))
|
67 |
-
axs[0].plot(x, train_metrics['loss'], label='train')
|
68 |
-
axs[0].plot(x, test_metrics['loss'], label='test')
|
69 |
-
axs[0].set(ylabel='Loss')
|
70 |
-
axs[0].set(xlabel='Epoch')
|
71 |
-
axs[0].set_title('loss',fontweight='bold')
|
72 |
-
axs[0].legend()
|
73 |
-
|
74 |
-
for i in range(len(metrics)):
|
75 |
-
axs[i+1].plot(x, train_metrics[metrics[i].__name__], label='train')
|
76 |
-
axs[i+1].plot(x, test_metrics[metrics[i].__name__], label='test')
|
77 |
-
axs[i+1].set(xlabel='Epoch')
|
78 |
-
axs[i+1].set_title(metrics[i].__name__,fontweight='bold')
|
79 |
-
axs[i+1].legend()
|
80 |
-
|
81 |
-
plt.tight_layout()
|
82 |
-
plt.savefig(os.path.join(save_dir, 'results.png'))
|
83 |
-
plt.close(fig)
|
84 |
-
|
85 |
-
def set_logger(log_path):
|
86 |
-
"""Set the logger to log info in terminal and file `log_path`.
|
87 |
-
In general, it is useful to have a logger so that every output to the terminal is saved
|
88 |
-
in a permanent file. Here we save it to `model_dir/train.log`.
|
89 |
-
Example:
|
90 |
-
```
|
91 |
-
logging.info("Starting training...")
|
92 |
-
```
|
93 |
-
Args:
|
94 |
-
log_path: (string) where to log
|
95 |
-
"""
|
96 |
-
logger = logging.getLogger()
|
97 |
-
logger.setLevel(logging.INFO)
|
98 |
-
logger.handlers.clear()
|
99 |
-
|
100 |
-
# Logging to a file
|
101 |
-
file_handler = logging.FileHandler(log_path)
|
102 |
-
file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
|
103 |
-
logger.addHandler(file_handler)
|
104 |
-
|
105 |
-
# Logging to console
|
106 |
-
stream_handler = logging.StreamHandler()
|
107 |
-
stream_handler.setFormatter(logging.Formatter('%(message)s'))
|
108 |
-
logger.addHandler(stream_handler)
|
109 |
-
|
110 |
-
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
|
111 |
-
"""Loads model parameters (state_dict) from file_path.
|
112 |
-
|
113 |
-
Args:
|
114 |
-
checkpoint: (string) filename which needs to be loaded
|
115 |
-
model: (torch.nn.Module) model for which the parameters are loaded
|
116 |
-
data_parallel: (bool) if the model is a data parallel model
|
117 |
-
"""
|
118 |
-
if not os.path.exists(checkpoint):
|
119 |
-
raise("File doesn't exist {}".format(checkpoint))
|
120 |
-
|
121 |
-
state_dict = torch.load(checkpoint)
|
122 |
-
|
123 |
-
if data_parallel:
|
124 |
-
state_dict['model_state_dict'] = {
|
125 |
-
'module.' + k: state_dict['model_state_dict'][k]
|
126 |
-
for k in state_dict['model_state_dict'].keys()}
|
127 |
-
model.load_state_dict(state_dict['model_state_dict'])
|
128 |
-
|
129 |
-
if optim is not None:
|
130 |
-
optim.load_state_dict(state_dict['optim_state_dict'])
|
131 |
-
|
132 |
-
if lr_sched is not None:
|
133 |
-
lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])
|
134 |
-
|
135 |
-
return state_dict['epoch'], state_dict['train_metrics'], \
|
136 |
-
state_dict['val_metrics']
|
137 |
-
|
138 |
-
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
|
139 |
-
train_metrics=None, val_metrics=None, data_parallel=False):
|
140 |
-
"""Saves model parameters (state_dict) to file_path.
|
141 |
-
|
142 |
-
Args:
|
143 |
-
checkpoint: (string) filename which needs to be loaded
|
144 |
-
model: (torch.nn.Module) model for which the parameters are loaded
|
145 |
-
data_parallel: (bool) if the model is a data parallel model
|
146 |
-
"""
|
147 |
-
if os.path.exists(checkpoint):
|
148 |
-
raise("File already exists {}".format(checkpoint))
|
149 |
-
|
150 |
-
model_state_dict = model.state_dict()
|
151 |
-
if data_parallel:
|
152 |
-
model_state_dict = {
|
153 |
-
k.partition('module.')[2]:
|
154 |
-
model_state_dict[k] for k in model_state_dict.keys()}
|
155 |
-
|
156 |
-
optim_state_dict = None if not optim else optim.state_dict()
|
157 |
-
lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()
|
158 |
-
|
159 |
-
state_dict = {
|
160 |
-
'epoch': epoch,
|
161 |
-
'model_state_dict': model_state_dict,
|
162 |
-
'optim_state_dict': optim_state_dict,
|
163 |
-
'lr_sched_state_dict': lr_sched_state_dict,
|
164 |
-
'train_metrics': train_metrics,
|
165 |
-
'val_metrics': val_metrics
|
166 |
-
}
|
167 |
-
|
168 |
-
torch.save(state_dict, checkpoint)
|
169 |
-
|
170 |
-
def model_size(model):
|
171 |
-
"""
|
172 |
-
Returns size of the `model` in millions of parameters.
|
173 |
-
"""
|
174 |
-
num_train_params = sum(
|
175 |
-
p.numel() for p in model.parameters() if p.requires_grad)
|
176 |
-
return num_train_params / 1e6
|
177 |
-
|
178 |
-
def run_time(model, inputs, profiling=False):
|
179 |
-
"""
|
180 |
-
Returns runtime of a model in ms.
|
181 |
-
"""
|
182 |
-
# Warmup
|
183 |
-
for _ in range(100):
|
184 |
-
output = model(*inputs)
|
185 |
-
|
186 |
-
with profile(activities=[ProfilerActivity.CPU],
|
187 |
-
record_shapes=True) as prof:
|
188 |
-
with record_function("model_inference"):
|
189 |
-
output = model(*inputs)
|
190 |
-
|
191 |
-
# Print profiling results
|
192 |
-
if profiling:
|
193 |
-
print(prof.key_averages().table(sort_by="self_cpu_time_total",
|
194 |
-
row_limit=20))
|
195 |
-
|
196 |
-
# Return runtime in ms
|
197 |
-
return prof.profiler.self_cpu_time_total / 1000
|
198 |
-
|
199 |
-
def format_lr_info(optimizer):
|
200 |
-
lr_info = ""
|
201 |
-
for i, pg in enumerate(optimizer.param_groups):
|
202 |
-
lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
|
203 |
-
i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr'])
|
204 |
-
return lr_info
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/__init__.py
DELETED
File without changes
|
src/training/eval.py
DELETED
@@ -1,214 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Test script to evaluate the model.
|
3 |
-
"""
|
4 |
-
|
5 |
-
import argparse
|
6 |
-
import importlib
|
7 |
-
import multiprocessing
|
8 |
-
import os, glob
|
9 |
-
import logging
|
10 |
-
|
11 |
-
import numpy as np
|
12 |
-
import torch
|
13 |
-
import pandas as pd
|
14 |
-
import torch.nn as nn
|
15 |
-
from torch.utils.tensorboard import SummaryWriter
|
16 |
-
from torch.profiler import profile, record_function, ProfilerActivity
|
17 |
-
from tqdm import tqdm # pylint: disable=unused-import
|
18 |
-
from torchmetrics.functional import(
|
19 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
20 |
-
signal_noise_ratio as snr,
|
21 |
-
signal_distortion_ratio as sdr,
|
22 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
23 |
-
|
24 |
-
from src.helpers import utils
|
25 |
-
from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
|
26 |
-
from src.training.synthetic_dataset import tensorboard_add_sample
|
27 |
-
|
28 |
-
def test_epoch(model: nn.Module, device: torch.device,
|
29 |
-
test_loader: torch.utils.data.dataloader.DataLoader,
|
30 |
-
n_items: int, loss_fn, metrics_fn,
|
31 |
-
profiling: bool = False, epoch: int = 0,
|
32 |
-
writer: SummaryWriter = None, data_params = None) -> float:
|
33 |
-
"""
|
34 |
-
Evaluate the network.
|
35 |
-
"""
|
36 |
-
model.eval()
|
37 |
-
metrics = {}
|
38 |
-
|
39 |
-
with torch.no_grad():
|
40 |
-
for batch_idx, (mixed, label, gt) in \
|
41 |
-
enumerate(tqdm(test_loader, desc='Test', ncols=100)):
|
42 |
-
mixed = mixed.to(device)
|
43 |
-
label = label.to(device)
|
44 |
-
gt = gt.to(device)
|
45 |
-
|
46 |
-
# Run through the model
|
47 |
-
with profile(activities=[ProfilerActivity.CPU],
|
48 |
-
record_shapes=True) as prof:
|
49 |
-
with record_function("model_inference"):
|
50 |
-
output = model(mixed, label)
|
51 |
-
if profiling:
|
52 |
-
logging.info(
|
53 |
-
prof.key_averages().table(sort_by="self_cpu_time_total",
|
54 |
-
row_limit=20))
|
55 |
-
|
56 |
-
# Compute loss
|
57 |
-
loss = loss_fn(output, gt)
|
58 |
-
|
59 |
-
# Compute metrics
|
60 |
-
metrics_batch = metrics_fn(mixed, output, gt)
|
61 |
-
metrics_batch['loss'] = [loss.item()]
|
62 |
-
metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total/1000]
|
63 |
-
for k in metrics_batch.keys():
|
64 |
-
if not k in metrics:
|
65 |
-
metrics[k] = metrics_batch[k]
|
66 |
-
else:
|
67 |
-
metrics[k] += metrics_batch[k]
|
68 |
-
|
69 |
-
if writer is not None:
|
70 |
-
if batch_idx == 0:
|
71 |
-
tensorboard_add_sample(
|
72 |
-
writer, tag='Test',
|
73 |
-
sample=(mixed[:8], label[:8], gt[:8], output[:8]),
|
74 |
-
step=epoch, params=data_params)
|
75 |
-
tensorboard_add_metrics(
|
76 |
-
writer, tag='Test', metrics=metrics_batch, label=label,
|
77 |
-
step=epoch)
|
78 |
-
|
79 |
-
if n_items is not None and batch_idx == (n_items - 1):
|
80 |
-
break
|
81 |
-
|
82 |
-
avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
|
83 |
-
avg_metrics_str = "Test:"
|
84 |
-
for m in avg_metrics.keys():
|
85 |
-
avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
|
86 |
-
logging.info(avg_metrics_str)
|
87 |
-
|
88 |
-
return avg_metrics
|
89 |
-
|
90 |
-
def evaluate(network, args: argparse.Namespace):
|
91 |
-
"""
|
92 |
-
Evaluate the model on a given dataset.
|
93 |
-
"""
|
94 |
-
|
95 |
-
# Load dataset
|
96 |
-
data_test = FSDSoundScapesDataset(**args.test_data)
|
97 |
-
logging.info("Loaded test dataset at %s containing %d elements" %
|
98 |
-
(args.test_data['input_dir'], len(data_test)))
|
99 |
-
|
100 |
-
# Set up the device and workers.
|
101 |
-
use_cuda = args.use_cuda and torch.cuda.is_available()
|
102 |
-
if use_cuda:
|
103 |
-
gpu_ids = args.gpu_ids if args.gpu_ids is not None\
|
104 |
-
else range(torch.cuda.device_count())
|
105 |
-
device_ids = [_ for _ in gpu_ids]
|
106 |
-
data_parallel = len(device_ids) > 1
|
107 |
-
device = 'cuda:%d' % device_ids[0]
|
108 |
-
torch.cuda.set_device(device_ids[0])
|
109 |
-
logging.info("Using CUDA devices: %s" % str(device_ids))
|
110 |
-
else:
|
111 |
-
data_parallel = False
|
112 |
-
device = torch.device('cpu')
|
113 |
-
logging.info("Using device: CPU")
|
114 |
-
|
115 |
-
# Set multiprocessing params
|
116 |
-
num_workers = min(multiprocessing.cpu_count(), args.n_workers)
|
117 |
-
kwargs = {
|
118 |
-
'num_workers': num_workers,
|
119 |
-
'pin_memory': True
|
120 |
-
} if use_cuda else {}
|
121 |
-
|
122 |
-
# Set up data loader
|
123 |
-
test_loader = torch.utils.data.DataLoader(data_test,
|
124 |
-
batch_size=args.eval_batch_size,
|
125 |
-
**kwargs)
|
126 |
-
|
127 |
-
# Set up model
|
128 |
-
model = network.Net(**args.model_params)
|
129 |
-
if use_cuda and data_parallel:
|
130 |
-
model = nn.DataParallel(model, device_ids=device_ids)
|
131 |
-
logging.info("Using data parallel model")
|
132 |
-
model.to(device)
|
133 |
-
|
134 |
-
# Load weights
|
135 |
-
if args.pretrain_path == "best":
|
136 |
-
ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
|
137 |
-
ckpts.sort(
|
138 |
-
key=lambda _: int(os.path.splitext(os.path.basename(_))[0]))
|
139 |
-
val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric]
|
140 |
-
best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
|
141 |
-
args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
|
142 |
-
logging.info(
|
143 |
-
"Found 'best' validation %s=%.02f at %s" %
|
144 |
-
(args.base_metric, val_metrics[best_epoch], args.pretrain_path))
|
145 |
-
if args.pretrain_path != "":
|
146 |
-
utils.load_checkpoint(
|
147 |
-
args.pretrain_path, model, data_parallel=data_parallel)
|
148 |
-
logging.info("Loaded pretrain weights from %s" % args.pretrain_path)
|
149 |
-
|
150 |
-
# Evaluate
|
151 |
-
try:
|
152 |
-
return test_epoch(
|
153 |
-
model, device, test_loader, args.n_items, network.loss,
|
154 |
-
network.metrics, args.profiling)
|
155 |
-
except KeyboardInterrupt:
|
156 |
-
print("Interrupted")
|
157 |
-
except Exception as _: # pylint: disable=broad-except
|
158 |
-
import traceback # pylint: disable=import-outside-toplevel
|
159 |
-
traceback.print_exc()
|
160 |
-
|
161 |
-
|
162 |
-
if __name__ == '__main__':
|
163 |
-
parser = argparse.ArgumentParser()
|
164 |
-
# Data Params
|
165 |
-
parser.add_argument('experiments', nargs='+', type=str,
|
166 |
-
default=None,
|
167 |
-
help="List of experiments to evaluate. "
|
168 |
-
"Provide only one experiment when providing "
|
169 |
-
"pretrained path. If pretrianed path is not "
|
170 |
-
"provided, epoch with best validation metric "
|
171 |
-
"is used for evaluation.")
|
172 |
-
parser.add_argument('--results', type=str, default="",
|
173 |
-
help="Path to the CSV file to store results.")
|
174 |
-
|
175 |
-
# System params
|
176 |
-
parser.add_argument('--n_items', type=int, default=None,
|
177 |
-
help="Number of items to test.")
|
178 |
-
parser.add_argument('--pretrain_path', type=str, default="best",
|
179 |
-
help="Path to pretrained weights")
|
180 |
-
parser.add_argument('--profiling', dest='profiling', action='store_true',
|
181 |
-
help="Enable or disable profiling.")
|
182 |
-
parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
|
183 |
-
help="Whether to use cuda")
|
184 |
-
parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
|
185 |
-
help="List of GPU ids used for training. "
|
186 |
-
"Eg., --gpu_ids 2 4. All GPUs are used by default.")
|
187 |
-
args = parser.parse_args()
|
188 |
-
|
189 |
-
results = []
|
190 |
-
|
191 |
-
for exp_dir in args.experiments:
|
192 |
-
eval_args = argparse.Namespace(**vars(args))
|
193 |
-
eval_args.exp_dir = exp_dir
|
194 |
-
|
195 |
-
utils.set_logger(os.path.join(exp_dir, 'eval.log'))
|
196 |
-
logging.info("Evaluating %s ..." % exp_dir)
|
197 |
-
|
198 |
-
# Load model and training params
|
199 |
-
params = utils.Params(os.path.join(exp_dir, 'config.json'))
|
200 |
-
for k, v in params.__dict__.items():
|
201 |
-
vars(eval_args)[k] = v
|
202 |
-
|
203 |
-
network = importlib.import_module(eval_args.model)
|
204 |
-
logging.info("Imported the model from '%s'." % eval_args.model)
|
205 |
-
|
206 |
-
curr_res = evaluate(network, eval_args)
|
207 |
-
curr_res['experiment'] = os.path.basename(exp_dir)
|
208 |
-
results.append(curr_res)
|
209 |
-
|
210 |
-
del eval_args
|
211 |
-
|
212 |
-
if args.results != "":
|
213 |
-
print("Writing results to %s" % args.results)
|
214 |
-
pd.DataFrame(results).to_csv(args.results, index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/synthetic_dataset.py
DELETED
@@ -1,168 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Torch dataset object for synthetically rendered spatial data.
|
3 |
-
"""
|
4 |
-
|
5 |
-
import os
|
6 |
-
import json
|
7 |
-
import random
|
8 |
-
from pathlib import Path
|
9 |
-
import logging
|
10 |
-
|
11 |
-
import numpy as np
|
12 |
-
import pandas as pd
|
13 |
-
import matplotlib.pyplot as plt
|
14 |
-
import scaper
|
15 |
-
import torch
|
16 |
-
import torchaudio
|
17 |
-
import torchaudio.transforms as AT
|
18 |
-
from random import randrange
|
19 |
-
|
20 |
-
class FSDSoundScapesDataset(torch.utils.data.Dataset): # type: ignore
|
21 |
-
"""
|
22 |
-
Base class for FSD Sound Scapes dataset
|
23 |
-
"""
|
24 |
-
|
25 |
-
_labels = [
|
26 |
-
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
|
27 |
-
"Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
|
28 |
-
"Computer_keyboard", "Cough", "Cowbell", "Double_bass",
|
29 |
-
"Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
|
30 |
-
"Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
|
31 |
-
"Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
|
32 |
-
"Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
|
33 |
-
"Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
|
34 |
-
"Trumpet", "Violin_or_fiddle", "Writing"]
|
35 |
-
|
36 |
-
def __init__(self, input_dir, dset='', sr=None,
|
37 |
-
resample_rate=None, max_num_targets=1):
|
38 |
-
assert dset in ['train', 'val', 'test'], \
|
39 |
-
"`dset` must be one of ['train', 'val', 'test']"
|
40 |
-
self.dset = dset
|
41 |
-
self.max_num_targets = max_num_targets
|
42 |
-
self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
|
43 |
-
if dset in ['train', 'val']:
|
44 |
-
self.bg_dir = os.path.join(
|
45 |
-
input_dir,
|
46 |
-
'TAU-acoustic-sounds/'
|
47 |
-
'TAU-urban-acoustic-scenes-2019-development')
|
48 |
-
else:
|
49 |
-
self.bg_dir = os.path.join(
|
50 |
-
input_dir,
|
51 |
-
'TAU-acoustic-sounds/'
|
52 |
-
'TAU-urban-acoustic-scenes-2019-evaluation')
|
53 |
-
logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
|
54 |
-
(dset, self.fg_dir, self.bg_dir))
|
55 |
-
|
56 |
-
self.samples = sorted(list(
|
57 |
-
Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))
|
58 |
-
|
59 |
-
jamsfile = os.path.join(self.samples[0], 'mixture.jams')
|
60 |
-
_, jams, _, _ = scaper.generate_from_jams(
|
61 |
-
jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
|
62 |
-
_sr = jams['annotations'][0]['sandbox']['scaper']['sr']
|
63 |
-
assert _sr == sr, "Sampling rate provided does not match the data"
|
64 |
-
|
65 |
-
if resample_rate is not None:
|
66 |
-
self.resampler = AT.Resample(sr, resample_rate)
|
67 |
-
self.sr = resample_rate
|
68 |
-
else:
|
69 |
-
self.resampler = lambda a: a
|
70 |
-
self.sr = sr
|
71 |
-
|
72 |
-
def _get_label_vector(self, labels):
|
73 |
-
"""
|
74 |
-
Generates a multi-hot vector corresponding to `labels`.
|
75 |
-
"""
|
76 |
-
vector = torch.zeros(len(FSDSoundScapesDataset._labels))
|
77 |
-
|
78 |
-
for label in labels:
|
79 |
-
idx = FSDSoundScapesDataset._labels.index(label)
|
80 |
-
assert vector[idx] == 0, "Repeated labels"
|
81 |
-
vector[idx] = 1
|
82 |
-
|
83 |
-
return vector
|
84 |
-
|
85 |
-
def __len__(self):
|
86 |
-
return len(self.samples)
|
87 |
-
|
88 |
-
def __getitem__(self, idx):
|
89 |
-
sample_path = self.samples[idx]
|
90 |
-
jamsfile = os.path.join(sample_path, 'mixture.jams')
|
91 |
-
|
92 |
-
mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
|
93 |
-
jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
|
94 |
-
isolated_events = {}
|
95 |
-
for e, a in zip(ann_list, event_audio_list[1:]):
|
96 |
-
# 0th event is background
|
97 |
-
isolated_events[e[2]] = a
|
98 |
-
gt_events = list(pd.read_csv(
|
99 |
-
os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])
|
100 |
-
|
101 |
-
mixture = torch.from_numpy(mixture).permute(1, 0)
|
102 |
-
mixture = self.resampler(mixture.to(torch.float))
|
103 |
-
|
104 |
-
if self.dset == 'train':
|
105 |
-
labels = random.sample(gt_events, randrange(1,self.max_num_targets+1))
|
106 |
-
elif self.dset == 'val':
|
107 |
-
labels = gt_events[:idx%self.max_num_targets+1]
|
108 |
-
elif self.dset == 'test':
|
109 |
-
labels = gt_events[:self.max_num_targets]
|
110 |
-
label_vector = self._get_label_vector(labels)
|
111 |
-
|
112 |
-
gt = torch.zeros_like(
|
113 |
-
torch.from_numpy(event_audio_list[1]).permute(1, 0))
|
114 |
-
for l in labels:
|
115 |
-
gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
|
116 |
-
gt = self.resampler(gt.to(torch.float))
|
117 |
-
|
118 |
-
return mixture, label_vector, gt #, jams
|
119 |
-
|
120 |
-
def tensorboard_add_sample(writer, tag, sample, step, params):
|
121 |
-
"""
|
122 |
-
Adds a sample of FSDSynthDataset to tensorboard.
|
123 |
-
"""
|
124 |
-
if params['resample_rate'] is not None:
|
125 |
-
sr = params['resample_rate']
|
126 |
-
else:
|
127 |
-
sr = params['sr']
|
128 |
-
resample_rate = 16000 if sr > 16000 else sr
|
129 |
-
|
130 |
-
m, l, gt, o = sample
|
131 |
-
m, gt, o = (
|
132 |
-
torchaudio.functional.resample(_, sr, resample_rate).cpu()
|
133 |
-
for _ in (m, gt, o))
|
134 |
-
|
135 |
-
def _add_audio(a, audio_tag, axis, plt_title):
|
136 |
-
for i, ch in enumerate(a):
|
137 |
-
axis.plot(ch, label='mic %d' % i)
|
138 |
-
writer.add_audio(
|
139 |
-
'%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate)
|
140 |
-
axis.set_title(plt_title)
|
141 |
-
axis.legend()
|
142 |
-
|
143 |
-
for b in range(m.shape[0]):
|
144 |
-
label = []
|
145 |
-
for i in range(len(l[b, :])):
|
146 |
-
if l[b, i] == 1:
|
147 |
-
label.append(FSDSoundScapesDataset._labels[i])
|
148 |
-
|
149 |
-
# Add waveforms
|
150 |
-
rows = 3 # input, output, gt
|
151 |
-
fig = plt.figure(figsize=(10, 2 * rows))
|
152 |
-
axes = fig.subplots(rows, 1, sharex=True)
|
153 |
-
_add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
|
154 |
-
_add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1], "Output (%s)" % label)
|
155 |
-
_add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2], "GT (%s)" % label)
|
156 |
-
writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)
|
157 |
-
|
158 |
-
def tensorboard_add_metrics(writer, tag, metrics, label, step):
|
159 |
-
"""
|
160 |
-
Add metrics to tensorboard.
|
161 |
-
"""
|
162 |
-
vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])
|
163 |
-
|
164 |
-
writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)
|
165 |
-
|
166 |
-
label_names = [FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
|
167 |
-
for l, v in zip(label_names, vals):
|
168 |
-
writer.add_histogram('%s/%s' % (tag, l), v, step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/train.py
DELETED
@@ -1,311 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
The main training script for training on synthetic data
|
3 |
-
"""
|
4 |
-
|
5 |
-
import argparse
|
6 |
-
import multiprocessing
|
7 |
-
import os
|
8 |
-
import logging
|
9 |
-
from pathlib import Path
|
10 |
-
import random
|
11 |
-
|
12 |
-
import numpy as np
|
13 |
-
import torch
|
14 |
-
import torch.nn as nn
|
15 |
-
import torch.nn.functional as F
|
16 |
-
import torch.optim as optim
|
17 |
-
from torch.utils.tensorboard import SummaryWriter
|
18 |
-
from tqdm import tqdm # pylint: disable=unused-import
|
19 |
-
from torchmetrics.functional import(
|
20 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
21 |
-
signal_noise_ratio as snr,
|
22 |
-
signal_distortion_ratio as sdr,
|
23 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
24 |
-
|
25 |
-
from src.helpers import utils
|
26 |
-
from src.training.eval import test_epoch
|
27 |
-
from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
|
28 |
-
from src.training.synthetic_dataset import tensorboard_add_sample
|
29 |
-
|
30 |
-
def train_epoch(model: nn.Module, device: torch.device,
|
31 |
-
optimizer: optim.Optimizer,
|
32 |
-
train_loader: torch.utils.data.dataloader.DataLoader,
|
33 |
-
n_items: int, epoch: int = 0,
|
34 |
-
writer: SummaryWriter = None, data_params = None) -> float:
|
35 |
-
|
36 |
-
"""
|
37 |
-
Train a single epoch.
|
38 |
-
"""
|
39 |
-
# Set the model to training.
|
40 |
-
model.train()
|
41 |
-
|
42 |
-
# Training loop
|
43 |
-
losses = []
|
44 |
-
metrics = {}
|
45 |
-
|
46 |
-
with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
|
47 |
-
for batch_idx, (mixed, label, gt) in enumerate(train_loader):
|
48 |
-
mixed = mixed.to(device)
|
49 |
-
label = label.to(device)
|
50 |
-
gt = gt.to(device)
|
51 |
-
|
52 |
-
# Reset grad
|
53 |
-
optimizer.zero_grad()
|
54 |
-
|
55 |
-
# Run through the model
|
56 |
-
output = model(mixed, label)
|
57 |
-
|
58 |
-
# Compute loss
|
59 |
-
loss = network.loss(output, gt)
|
60 |
-
|
61 |
-
losses.append(loss.item())
|
62 |
-
|
63 |
-
# Backpropagation
|
64 |
-
loss.backward()
|
65 |
-
|
66 |
-
# Gradient clipping
|
67 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
|
68 |
-
|
69 |
-
# Update the weights
|
70 |
-
optimizer.step()
|
71 |
-
|
72 |
-
metrics_batch = network.metrics(mixed.detach(), output.detach(),
|
73 |
-
gt.detach())
|
74 |
-
for k in metrics_batch.keys():
|
75 |
-
if not k in metrics:
|
76 |
-
metrics[k] = metrics_batch[k]
|
77 |
-
else:
|
78 |
-
metrics[k] += metrics_batch[k]
|
79 |
-
|
80 |
-
if writer is not None and batch_idx == 0:
|
81 |
-
tensorboard_add_sample(
|
82 |
-
writer, tag='Train',
|
83 |
-
sample=(mixed.detach()[:8], label.detach()[:8],
|
84 |
-
gt.detach()[:8], output.detach()[:8]),
|
85 |
-
step=epoch, params=data_params)
|
86 |
-
|
87 |
-
# Show current loss in the progress meter
|
88 |
-
t.set_postfix(loss='%.05f'%loss.item())
|
89 |
-
t.update()
|
90 |
-
|
91 |
-
if n_items is not None and batch_idx == n_items:
|
92 |
-
break
|
93 |
-
|
94 |
-
avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
|
95 |
-
avg_metrics['loss'] = np.mean(losses)
|
96 |
-
avg_metrics_str = "Train:"
|
97 |
-
for m in avg_metrics.keys():
|
98 |
-
avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
|
99 |
-
logging.info(avg_metrics_str)
|
100 |
-
|
101 |
-
return avg_metrics
|
102 |
-
|
103 |
-
|
104 |
-
def train(args: argparse.Namespace):
    """
    Train the network.

    Builds the train/validation datasets and loaders, instantiates the model
    from the module-level ``network`` module (bound by the script entry
    point before this function is called), then runs the epoch loop: one
    ``train_epoch`` and one ``test_epoch`` per epoch, an optional LR
    scheduler step, tensorboard logging, and a checkpoint written to
    ``<exp_dir>/<epoch>.pt``.

    Args:
        args: Namespace carrying both CLI options and config values. The
            fields read here are: ``train_data``, ``val_data``, ``use_cuda``,
            ``gpu_ids``, ``n_workers``, ``batch_size``, ``eval_batch_size``,
            ``model_params``, ``optim``, ``lr_sched``, ``fix_lr_epochs``,
            ``base_metric``, ``start_epoch``, ``epochs``, ``exp_dir``,
            ``detect_anomaly``, ``n_train_items``, ``n_test_items`` and
            ``writer``.

    Returns:
        ``(train_metrics, val_metrics)``: dicts mapping a metric name to the
        list of its per-epoch values, when the loop runs to completion.
        ``None`` (implicitly) if the loop is interrupted from the keyboard or
        aborts with an exception; in the latter case the traceback is
        printed rather than propagated.

    Raises:
        AssertionError: if ``args.start_epoch`` is negative.
    """

    # Load dataset
    data_train = Dataset(**args.train_data)
    logging.info("Loaded train dataset at %s containing %d elements",
                 args.train_data['input_dir'], len(data_train))
    data_val = Dataset(**args.val_data)
    logging.info("Loaded test dataset at %s containing %d elements",
                 args.val_data['input_dir'], len(data_val))

    # Set up the device and workers.
    use_cuda = args.use_cuda and torch.cuda.is_available()
    if use_cuda:
        # Fall back to every visible GPU when no explicit ids were given.
        gpu_ids = (args.gpu_ids if args.gpu_ids is not None
                   else range(torch.cuda.device_count()))
        device_ids = list(gpu_ids)
        data_parallel = len(device_ids) > 1
        device = 'cuda:%d' % device_ids[0]
        torch.cuda.set_device(device_ids[0])
        logging.info("Using CUDA devices: %s", str(device_ids))
    else:
        data_parallel = False
        device = torch.device('cpu')
        logging.info("Using device: CPU")

    # Set multiprocessing params
    # NOTE: the worker count is only forwarded to the loaders when CUDA is
    # in use; on CPU the DataLoader defaults apply.
    num_workers = min(multiprocessing.cpu_count(), args.n_workers)
    kwargs = {
        'num_workers': num_workers,
        'pin_memory': True
    } if use_cuda else {}

    # Set up data loaders
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=args.eval_batch_size,
                                             **kwargs)

    # Set up model
    model = network.Net(**args.model_params)

    # Add graph to tensorboard with example train samples
    # _mixed, _label, _ = next(iter(val_loader))
    # args.writer.add_graph(model, (_mixed, _label))

    if use_cuda and data_parallel:
        model = nn.DataParallel(model, device_ids=device_ids)
        logging.info("Using data parallel model")
    model.to(device)

    # Set up the optimizer
    logging.info("Initializing optimizer with %s", str(args.optim))
    optimizer = network.optimizer(model, **args.optim,
                                  data_parallel=data_parallel)
    logging.info('Learning rates initialized to:' +
                 utils.format_lr_info(optimizer))

    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **args.lr_sched)
    logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s",
                 args.fix_lr_epochs, str(args.lr_sched))

    base_metric = args.base_metric
    train_metrics = {}
    val_metrics = {}

    # Load the model if `args.start_epoch` is greater than 0. This will load the
    # model from epoch = `args.start_epoch - 1`
    assert args.start_epoch >= 0, "start_epoch must be non-negative."
    if args.start_epoch > 0:
        checkpoint_path = os.path.join(args.exp_dir,
                                       '%d.pt' % (args.start_epoch - 1))
        _, train_metrics, val_metrics = utils.load_checkpoint(
            checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
            data_parallel=data_parallel)
        logging.info("Loaded checkpoint from %s", checkpoint_path)
        logging.info("Learning rates restored to:" +
                     utils.format_lr_info(optimizer))

    # Training loop
    try:
        torch.autograd.set_detect_anomaly(args.detect_anomaly)
        # NOTE: the range is inclusive of `args.epochs`.
        for epoch in range(args.start_epoch, args.epochs + 1):
            logging.info("Epoch %d:", epoch)
            checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
            # Refuse to overwrite a checkpoint left by an earlier run.
            assert not os.path.exists(checkpoint_file), \
                "Checkpoint file %s already exists" % checkpoint_file
            curr_train_metrics = train_epoch(model, device, optimizer,
                                             train_loader, args.n_train_items,
                                             epoch=epoch, writer=args.writer,
                                             data_params=args.train_data)
            curr_test_metrics = test_epoch(model, device, val_loader,
                                           args.n_test_items, network.loss,
                                           network.metrics, epoch=epoch,
                                           writer=args.writer,
                                           data_params=args.val_data)
            # LR scheduler
            if epoch >= args.fix_lr_epochs:
                lr_scheduler.step(curr_test_metrics[base_metric])
                logging.info(
                    "LR after scheduling step: %s",
                    [group['lr'] for group in optimizer.param_groups])

            # Write metrics to tensorboard
            args.writer.add_scalars('Train', curr_train_metrics, epoch)
            args.writer.add_scalars('Val', curr_test_metrics, epoch)
            args.writer.flush()

            # Append this epoch's values to the running per-metric histories.
            for name, value in curr_train_metrics.items():
                train_metrics.setdefault(name, []).append(value)

            for name, value in curr_test_metrics.items():
                val_metrics.setdefault(name, []).append(value)

            # The latest value equalling the running maximum marks a new best.
            if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
                logging.info("Found best validation %s!", base_metric)

            utils.save_checkpoint(
                checkpoint_file, epoch, model, optimizer, lr_scheduler,
                train_metrics, val_metrics, data_parallel)
            logging.info("Saved checkpoint at %s", checkpoint_file)

            utils.save_graph(train_metrics, val_metrics, args.exp_dir)

        return train_metrics, val_metrics

    except KeyboardInterrupt:
        print("Interrupted")
    except Exception:  # pylint: disable=broad-except
        # Best effort: report the failure and fall through (returning None)
        # so the caller can still close its writers.
        import traceback  # pylint: disable=import-outside-toplevel
        traceback.print_exc()
|
247 |
-
|
248 |
-
|
249 |
-
if __name__ == '__main__':
    # Script entry point: parse CLI options, merge in the experiment's
    # config.json, import the model module named by the config, then train.
    import importlib  # pylint: disable=import-outside-toplevel

    parser = argparse.ArgumentParser()
    # Data Params
    parser.add_argument('exp_dir', type=str,
                        default='./experiments/fsd_mask_label_mult',
                        help="Path to save checkpoints and logs.")

    parser.add_argument('--n_train_items', type=int, default=None,
                        help="Number of items to train on in each epoch")
    parser.add_argument('--n_test_items', type=int, default=None,
                        help="Number of items to test.")
    parser.add_argument('--start_epoch', type=int, default=0,
                        help="Start epoch")
    parser.add_argument('--pretrain_path', type=str,
                        help="Path to pretrained weights")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                        "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    parser.add_argument('--detect_anomaly', dest='detect_anomaly',
                        action='store_true',
                        help="Whether to enable autograd anomaly detection")
    parser.add_argument('--wandb', dest='wandb', action='store_true',
                        help="Whether to sync tensorboard to wandb")

    args = parser.parse_args()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    random.seed(230)
    np.random.seed(230)
    if args.use_cuda:
        torch.cuda.manual_seed(230)

    # Set up checkpoints
    os.makedirs(args.exp_dir, exist_ok=True)

    utils.set_logger(os.path.join(args.exp_dir, 'train.log'))

    # Load model and training params; config values are copied onto `args`
    # and override any CLI option with the same name.
    params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
    vars(args).update(params.__dict__)

    # Initialize tensorboard writer
    tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
    args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
    if args.wandb:
        import wandb  # pylint: disable=import-outside-toplevel
        wandb.init(
            project='Semaudio', sync_tensorboard=True,
            dir=tensorboard_dir, name=os.path.basename(args.exp_dir))

    # Bind the model module to the module-level name `network` used by
    # train(). import_module avoids building source code from a config
    # string and returns the leaf module, as `import a.b as network` does.
    network = importlib.import_module(args.model)
    logging.info("Imported the model from '%s'.", args.model)

    try:
        train(args)
    finally:
        # Release the writer (and the wandb run) even if training setup
        # raised before the epoch loop started.
        args.writer.close()
        if args.wandb:
            wandb.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|