Base code
- app.py +62 -0
- default_config.json +60 -0
- requirements.txt +9 -0
- src/__init__.py +0 -0
- src/helpers/__init__.py +0 -0
- src/helpers/utils.py +205 -0
- src/training/__init__.py +0 -0
- src/training/dcc_tf.py +486 -0
- src/training/eval.py +214 -0
- src/training/synthetic_dataset.py +168 -0
- src/training/train.py +311 -0
app.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import os

import wget
import torch
import torchaudio
import gradio as gr

from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]

if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)

# Instantiate model
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()

def waveformer(audio, label_choices):
    # Read input audio
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Expected a sampling rate of 44100 Hz, got %d Hz" % fs)
    mixture = torch.from_numpy(mixture).unsqueeze(0)

    # Construct the query vector
    if len(label_choices) == 0:
        raise ValueError("At least one target class must be selected")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    with torch.no_grad():
        output = model(mixture, query)

    return fs, output.squeeze(0).numpy()


label_checkbox = gr.CheckboxGroup(choices=TARGETS)
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")
demo.launch()
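For reference, the same pipeline can be driven without the Gradio UI. Below is a minimal offline sketch of the `waveformer` callback; the WAV path and the "Meow" target are illustrative assumptions, not part of this commit:

import torch
import torchaudio

from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()

# Hypothetical input file; the checkpoint expects 44.1 kHz audio.
mixture, fs = torchaudio.load('mixture.wav')  # [channels, T], float32
assert fs == 44100

# 41-dim query vector; index 27 is "Meow" in the TARGETS list above.
query = torch.zeros(1, 41)
query[0, 27] = 1.

with torch.no_grad():
    # Net's forward expects [B, n_mics, T], hence the unsqueeze.
    output = model(mixture[:1].unsqueeze(0), query)
torchaudio.save('output.wav', output.squeeze(0), fs)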
default_config.json
ADDED
@@ -0,0 +1,60 @@
{
    "model": "src.training.dcc_tf",
    "model_params":
    {
        "label_len": 41,
        "L": 32,
        "enc_dim": 512,
        "num_enc_layers": 10,
        "dec_dim": 256,
        "num_dec_layers": 1,
        "dec_buf_len": 13,
        "dec_chunk_size": 13,
        "out_buf_len": 4,
        "use_pos_enc": "true"
    },
    "train_data":
    {
        "input_dir": "data/FSDSoundScapes",
        "dset": "train",
        "sr": 44100,
        "resample_rate": null,
        "max_num_targets": 3
    },
    "val_data":
    {
        "input_dir": "data/FSDSoundScapes",
        "dset": "val",
        "sr": 44100,
        "resample_rate": null,
        "max_num_targets": 3
    },
    "test_data":
    {
        "input_dir": "data/FSDSoundScapes",
        "dset": "test",
        "sr": 44100,
        "resample_rate": null,
        "max_num_targets": 3
    },
    "optim":
    {
        "lr": 5e-4,
        "weight_decay": 0.0
    },
    "lr_sched":
    {
        "mode": "max",
        "factor": 0.1,
        "patience": 5,
        "min_lr": 5e-6,
        "threshold": 0.1,
        "threshold_mode": "abs"
    },
    "base_metric": "scale_invariant_signal_noise_ratio",
    "fix_lr_epochs": 50,
    "epochs": 150,
    "batch_size": 16,
    "eval_batch_size": 64,
    "n_workers": 16
}
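One detail worth flagging: `use_pos_enc` is the JSON string "true", not the boolean true. The model only tests this field for truthiness (`if self.use_pos_enc:` in dcc_tf.py), so any non-empty string, including "false", would enable positional encoding; disabling it requires false or null.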
requirements.txt
ADDED
@@ -0,0 +1,9 @@
### Requirements
librosa
torch
torchaudio
soundfile
numpy
speechbrain
wget
src/__init__.py
ADDED
File without changes
src/helpers/__init__.py
ADDED
File without changes
src/helpers/utils.py
ADDED
@@ -0,0 +1,205 @@
"""A collection of useful helper functions"""

import os
import logging
import json

import torch
from torch.profiler import profile, record_function, ProfilerActivity
import pandas as pd
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)
import matplotlib.pyplot as plt

class Params():
    """Class that loads hyperparameters from a json file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__

def save_graph(train_metrics, test_metrics, save_dir):
    metrics = [snr, si_snr]
    results = {'train_loss': train_metrics['loss'],
               'test_loss': test_metrics['loss']}

    for m_fn in metrics:
        results["train_" + m_fn.__name__] = train_metrics[m_fn.__name__]
        results["test_" + m_fn.__name__] = test_metrics[m_fn.__name__]

    results_pd = pd.DataFrame(results)

    results_pd.to_csv(os.path.join(save_dir, 'results.csv'))

    fig, temp_ax = plt.subplots(2, 3, figsize=(15, 10))
    axs = []
    for i in temp_ax:
        for j in i:
            axs.append(j)

    x = range(len(train_metrics['loss']))
    axs[0].plot(x, train_metrics['loss'], label='train')
    axs[0].plot(x, test_metrics['loss'], label='test')
    axs[0].set(ylabel='Loss')
    axs[0].set(xlabel='Epoch')
    axs[0].set_title('loss', fontweight='bold')
    axs[0].legend()

    for i in range(len(metrics)):
        axs[i + 1].plot(x, train_metrics[metrics[i].__name__], label='train')
        axs[i + 1].plot(x, test_metrics[metrics[i].__name__], label='test')
        axs[i + 1].set(xlabel='Epoch')
        axs[i + 1].set_title(metrics[i].__name__, fontweight='bold')
        axs[i + 1].legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'results.png'))
    plt.close(fig)

def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    # Logging to a file
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logger.addHandler(file_handler)

    # Logging to console
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(stream_handler)

def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
    """Loads model parameters (state_dict) from file_path.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        data_parallel: (bool) if the model is a data parallel model
    """
    if not os.path.exists(checkpoint):
        # `raise` on a bare string is invalid in Python 3; raise a proper exception.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    state_dict = torch.load(checkpoint)

    if data_parallel:
        state_dict['model_state_dict'] = {
            'module.' + k: state_dict['model_state_dict'][k]
            for k in state_dict['model_state_dict'].keys()}
    model.load_state_dict(state_dict['model_state_dict'])

    if optim is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    if lr_sched is not None:
        lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])

    return state_dict['epoch'], state_dict['train_metrics'], \
           state_dict['val_metrics']

def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
                    train_metrics=None, val_metrics=None, data_parallel=False):
    """Saves model parameters (state_dict) to file_path.

    Args:
        checkpoint: (string) filename to save to
        model: (torch.nn.Module) model whose parameters are saved
        data_parallel: (bool) if the model is a data parallel model
    """
    if os.path.exists(checkpoint):
        raise FileExistsError("File already exists {}".format(checkpoint))

    model_state_dict = model.state_dict()
    if data_parallel:
        model_state_dict = {
            k.partition('module.')[2]:
            model_state_dict[k] for k in model_state_dict.keys()}

    optim_state_dict = None if not optim else optim.state_dict()
    lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()

    state_dict = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optim_state_dict': optim_state_dict,
        'lr_sched_state_dict': lr_sched_state_dict,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics
    }

    torch.save(state_dict, checkpoint)

def model_size(model):
    """
    Returns size of the `model` in millions of parameters.
    """
    num_train_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    return num_train_params / 1e6

def run_time(model, inputs, profiling=False):
    """
    Returns runtime of a model in ms.
    """
    # Warmup
    for _ in range(100):
        output = model(*inputs)

    with profile(activities=[ProfilerActivity.CPU],
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            output = model(*inputs)

    # Print profiling results
    if profiling:
        print(prof.key_averages().table(sort_by="self_cpu_time_total",
                                        row_limit=20))

    # Return runtime in ms
    return prof.profiler.self_cpu_time_total / 1000

def format_lr_info(optimizer):
    lr_info = ""
    for i, pg in enumerate(optimizer.param_groups):
        lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
            i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr'])
    return lr_info
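The profiling helpers compose with the model defined in src/training/dcc_tf.py. A minimal sketch, assuming a 1-second random input and a single active class (both illustrative; run_time's 100-iteration warmup makes this take a moment on CPU):

import torch
from src.helpers.utils import model_size, run_time
from src.training.dcc_tf import Net

model = Net(label_len=41, L=32).eval()
x = torch.randn(1, 1, 44100)   # 1 s of 44.1 kHz audio (assumed)
label = torch.zeros(1, 41)
label[0, 0] = 1.

print("Params: %.2fM" % model_size(model))
with torch.no_grad():
    print("Runtime: %.1f ms" % run_time(model, (x, label)))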
src/training/__init__.py
ADDED
File without changes
src/training/dcc_tf.py
ADDED
@@ -0,0 +1,486 @@
import math
from collections import OrderedDict
from typing import Optional

from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)

from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding

def mod_pad(x, chunk_size, pad):
    # Mod pad the input to perform an integer number of inferences
    mod = 0
    if (x.shape[-1] % chunk_size) != 0:
        mod = chunk_size - (x.shape[-1] % chunk_size)
    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)

    return x, mod

class LayerNormPermuted(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super(LayerNormPermuted, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]
        """
        x = x.permute(0, 2, 1)  # [B, T, C]
        x = super().forward(x)
        x = x.permute(0, 2, 1)  # [B, C, T]
        return x

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolutions
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation):
        super(DepthwiseSeparableConv, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size, stride,
                      padding, groups=in_channels, dilation=dilation),
            LayerNormPermuted(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0),
            LayerNormPermuted(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)

class DilatedCausalConvEncoder(nn.Module):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.
    """
    def __init__(self, channels, num_layers, kernel_size=3):
        super(DilatedCausalConvEncoder, self).__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # Compute buffer lengths for each layer
        # buf_length[i] = (kernel_size - 1) * dilation[i]
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Compute buffer start indices for each layer
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(
                self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # contextual encoded input.
        _dcc_layers = OrderedDict()
        for i in range(num_layers):
            dcc_layer = DepthwiseSeparableConv(
                channels, channels, kernel_size=3, stride=1,
                padding=0, dilation=2**i)
            _dcc_layers.update({'dcc_%d' % i: dcc_layer})
        self.dcc_layers = nn.Sequential(_dcc_layers)

    def init_ctx_buf(self, batch_size, device):
        """
        Returns an initialized context buffer for a given batch size.
        """
        return torch.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)),
            device=device)

    def forward(self, x, ctx_buf):
        """
        Encodes input audio `x` into latent space, and aggregates
        contextual information in `ctx_buf`. Also generates a new context
        buffer with updated context.

        Args:
            x: [B, in_channels, T]
                Input multi-channel audio.
            ctx_buf: {[B, channels, self.buf_length[0]], ...}
                A list of tensors holding context for each dilated
                causal conv layer. (len(ctx_buf) == self.num_layers)
        Returns:
            ctx_buf: {[B, channels, self.buf_length[0]], ...}
                Updated context buffer with output as the
                last element.
        """
        T = x.shape[-1]  # Sequence length

        for i in range(self.num_layers):
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]

            # DCC input: concatenation of current output and context
            dcc_in = torch.cat(
                (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1)

            # Push current output to the context buffer
            ctx_buf[..., buf_start_idx:buf_end_idx] = \
                dcc_in[..., -self.buf_lengths[i]:]

            # Residual connection
            x = x + self.dcc_layers[i](dcc_in)

        return x, ctx_buf

class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
    """
    Adapted from:
    "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
    "0caf6ad71c46488f76d89845b0123d2550ef792f/"
    "causal_transformer_decoder/model.py#L77"
    """
    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        chunk_size: int = 1
    ) -> Tensor:
        tgt_last_tok = tgt[:, -chunk_size:, :]
        ca_map = None  # stays None when no memory is provided

        # self attention part
        tmp_tgt, sa_map = self.self_attn(
            tgt_last_tok,
            tgt,
            tgt,
            attn_mask=None,  # not needed because we only care about the last token
            key_padding_mask=None,
        )
        tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
        tgt_last_tok = self.norm1(tgt_last_tok)

        # encoder-decoder attention
        if memory is not None:
            tmp_tgt, ca_map = self.multihead_attn(
                tgt_last_tok,
                memory,
                memory,
                attn_mask=None,  # Attend to the entire chunk
                key_padding_mask=None,
            )
            tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
            tgt_last_tok = self.norm2(tgt_last_tok)

        # final feed-forward network
        tmp_tgt = self.linear2(
            self.dropout(self.activation(self.linear1(tgt_last_tok)))
        )
        tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
        tgt_last_tok = self.norm3(tgt_last_tok)
        return tgt_last_tok, sa_map, ca_map

class CausalTransformerDecoder(nn.Module):
    """
    A causal transformer decoder which decodes input vectors using
    precisely `ctx_len` past vectors in the sequence, and using no future
    vectors at all.
    """
    def __init__(self, model_dim, ctx_len, chunk_size, num_layers,
                 nhead, use_pos_enc, ff_dim):
        super(CausalTransformerDecoder, self).__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.ctx_len = ctx_len
        self.chunk_size = chunk_size
        self.nhead = nhead
        self.use_pos_enc = use_pos_enc
        self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size)
        self.pos_enc = PositionalEncoding(model_dim, max_len=200)
        self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer(
            d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim,
            batch_first=True) for _ in range(num_layers)])

    def init_ctx_buf(self, batch_size, device):
        return torch.zeros(
            (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim),
            device=device)

    def _causal_unfold(self, x):
        """
        Unfolds the sequence into a batch of sequences
        prepended with `ctx_len` previous values.

        Args:
            x: [B, ctx_len + L, C]
        Returns:
            [B * L, ctx_len + chunk_size, C]
        """
        B, T, C = x.shape
        x = x.permute(0, 2, 1)  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * (ctx_len + chunk_size), -1]
        x = x.permute(0, 2, 1)
        x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size)
        x = x.reshape(-1, C, self.ctx_len + self.chunk_size)
        x = x.permute(0, 2, 1)
        return x

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt, mem: [B, model_dim, T]
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Input sequence length
        B, C, T = tgt.shape

        tgt = tgt.permute(0, 2, 1)
        mem = mem.permute(0, 2, 1)

        # Prepend mem with the context
        mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
        ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: required to ensure the model
        # doesn't trigger an out-of-memory error when working
        # on long sequences.
        K = 1000

        for i, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Update the tgt with context
            tgt = torch.cat((ctx_buf[:, i + 1, :, :], tgt), dim=1)
            ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and i == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
            tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            # Process the unfolded batch in slices of K to bound memory use.
            # (`j` avoids shadowing the layer index `i`.)
            for j in range(int(math.ceil(tgt.shape[0] / K))):
                tgt[j*K:(j+1)*K], _sa_map, _ca_map = tf_dec_layer(
                    tgt_ctx[j*K:(j+1)*K], mem_ctx[j*K:(j+1)*K],
                    self.chunk_size)
            tgt = tgt.reshape(B, T, C)

        tgt = tgt.permute(0, 2, 1)
        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf

class MaskNet(nn.Module):
    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection, proj):
        super(MaskNet, self).__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
                                                num_layers=num_enc_layers)

        # Project between encoder and decoder dimensions
        self.proj_e2d_e = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_e2d_l = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_d2e = nn.Sequential(
            nn.Conv1d(dec_dim, enc_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())

        # Transformer decoder that operates on chunks of size
        # buffer size.
        self.decoder = CausalTransformerDecoder(
            model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
            num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
            ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask based on the encoded input and the label
        embedding `l`.

        Args:
            x: [B, C, T]
                Input audio sequence
            l: [B, C]
                Label embedding
            enc_buf, dec_buf:
                Context buffers maintained by the DCC encoder and the
                transformer decoder.
        """
        # Encode the input
        e, enc_buf = self.encoder(x, enc_buf)

        # Label integration
        l = l.unsqueeze(2) * e

        # Project to `dec_dim` dimensions
        if self.proj:
            e = self.proj_e2d_e(e)
            m = self.proj_e2d_l(l)
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(m, e, dec_buf)
        else:
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(l, e, dec_buf)

        # Project mask to encoder dimensions
        if self.proj:
            m = self.proj_d2e(m)

        # Final mask after residual connection
        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf

class Net(nn.Module):
    def __init__(self, label_len, L=8,
                 enc_dim=512, num_enc_layers=10,
                 dec_dim=256, dec_buf_len=100, num_dec_layers=2,
                 dec_chunk_size=72, out_buf_len=2,
                 use_pos_enc=True, skip_connection=True, proj=True, lookahead=True):
        super(Net, self).__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1d(in_channels=1,
                      out_channels=enc_dim, kernel_size=kernel_size, stride=L,
                      padding=0, bias=False),
            nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(
            nn.Linear(label_len, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, enc_dim),
            nn.LayerNorm(enc_dim),
            nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(
            enc_dim=enc_dim, num_enc_layers=num_enc_layers,
            dec_dim=dec_dim, dec_buf_len=dec_buf_len,
            dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
            use_pos_enc=use_pos_enc, skip_connection=skip_connection, proj=proj)

        # Output conv layer
        self.out_conv = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels=enc_dim, out_channels=1,
                kernel_size=(out_buf_len + 1) * L,
                stride=L,
                padding=out_buf_len * L, bias=False),
            nn.Tanh())

    def init_buffers(self, batch_size, device):
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
        out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
                              device=device)
        return enc_buf, dec_buf, out_buf

    def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
                init_out_buf=None, pad=True):
        """
        Extracts the audio corresponding to the `label` in the given
        mixture `x`. Generates `chunk_size` samples per iteration.

        Args:
            x: [B, n_mics, T]
                input audio mixture
            label: [B, num_labels]
                one hot label
        Returns:
            out: [B, n_spk, T]
                extracted audio with sounds corresponding to the `label`
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
            assert init_enc_buf is None and \
                   init_dec_buf is None and \
                   init_out_buf is None, \
                "All buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(
                x.shape[0], x.device)
        else:
            enc_buf, dec_buf, out_buf = \
                init_enc_buf, init_dec_buf, init_out_buf

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label)  # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask and decode
        x = x * m
        x = torch.cat((out_buf, x), dim=-1)
        out_buf = x[..., -self.out_buf_len:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        if init_enc_buf is None:
            return x
        else:
            return x, enc_buf, dec_buf, out_buf

# Define optimizer, loss and metrics

def optimizer(model, data_parallel=False, **kwargs):
    return optim.Adam(model.parameters(), **kwargs)

def loss(pred, tgt):
    return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()

def metrics(mixed, output, gt):
    """ Function to compute metrics """
    metrics = {}

    def metric_i(metric, src, pred, tgt):
        _vals = []
        for s, t, p in zip(src, tgt, pred):
            _vals.append((metric(p, t) - metric(s, t)).cpu().item())
        return _vals

    for m_fn in [snr, si_snr]:
        metrics[m_fn.__name__] = metric_i(m_fn,
                                          mixed[:, :gt.shape[1], :],
                                          output,
                                          gt)

    return metrics
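Since every stateful component exposes an init-buffer method, Net can run chunk by chunk. A hedged streaming sketch using the hyperparameters from default_config.json (the random input and ten-chunk loop are illustrative; lookahead=False is chosen so each call maps dec_chunk_size * L input samples to exactly as many output samples):

import torch
from src.training.dcc_tf import Net

model = Net(label_len=41, L=32, dec_buf_len=13, dec_chunk_size=13,
            num_dec_layers=1, out_buf_len=4, lookahead=False).eval()
label = torch.zeros(1, 41)
label[0, 3] = 1.

enc_buf, dec_buf, out_buf = model.init_buffers(batch_size=1, device='cpu')
chunk_len = 13 * 32                       # dec_chunk_size * L samples
stream = torch.randn(1, 1, 10 * chunk_len)

chunks_out = []
with torch.no_grad():
    for i in range(10):
        chunk = stream[..., i * chunk_len:(i + 1) * chunk_len]
        out, enc_buf, dec_buf, out_buf = model(
            chunk, label, enc_buf, dec_buf, out_buf, pad=False)
        chunks_out.append(out)
result = torch.cat(chunks_out, dim=-1)    # same length as `stream`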
src/training/eval.py
ADDED
@@ -0,0 +1,214 @@
"""
Test script to evaluate the model.
"""

import argparse
import importlib
import multiprocessing
import os, glob
import logging

import numpy as np
import torch
import pandas as pd
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, record_function, ProfilerActivity
from tqdm import tqdm  # pylint: disable=unused-import
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)

from src.helpers import utils
from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
from src.training.synthetic_dataset import tensorboard_add_sample

def test_epoch(model: nn.Module, device: torch.device,
               test_loader: torch.utils.data.dataloader.DataLoader,
               n_items: int, loss_fn, metrics_fn,
               profiling: bool = False, epoch: int = 0,
               writer: SummaryWriter = None, data_params = None) -> float:
    """
    Evaluate the network.
    """
    model.eval()
    metrics = {}

    with torch.no_grad():
        for batch_idx, (mixed, label, gt) in \
                enumerate(tqdm(test_loader, desc='Test', ncols=100)):
            mixed = mixed.to(device)
            label = label.to(device)
            gt = gt.to(device)

            # Run through the model
            with profile(activities=[ProfilerActivity.CPU],
                         record_shapes=True) as prof:
                with record_function("model_inference"):
                    output = model(mixed, label)
            if profiling:
                logging.info(
                    prof.key_averages().table(sort_by="self_cpu_time_total",
                                              row_limit=20))

            # Compute loss
            loss = loss_fn(output, gt)

            # Compute metrics
            metrics_batch = metrics_fn(mixed, output, gt)
            metrics_batch['loss'] = [loss.item()]
            metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total / 1000]
            for k in metrics_batch.keys():
                if k not in metrics:
                    metrics[k] = metrics_batch[k]
                else:
                    metrics[k] += metrics_batch[k]

            if writer is not None:
                if batch_idx == 0:
                    tensorboard_add_sample(
                        writer, tag='Test',
                        sample=(mixed[:8], label[:8], gt[:8], output[:8]),
                        step=epoch, params=data_params)
                tensorboard_add_metrics(
                    writer, tag='Test', metrics=metrics_batch, label=label,
                    step=epoch)

            if n_items is not None and batch_idx == (n_items - 1):
                break

        avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
        avg_metrics_str = "Test:"
        for m in avg_metrics.keys():
            avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
        logging.info(avg_metrics_str)

        return avg_metrics

def evaluate(network, args: argparse.Namespace):
    """
    Evaluate the model on a given dataset.
    """

    # Load dataset
    data_test = FSDSoundScapesDataset(**args.test_data)
    logging.info("Loaded test dataset at %s containing %d elements" %
                 (args.test_data['input_dir'], len(data_test)))

    # Set up the device and workers.
    use_cuda = args.use_cuda and torch.cuda.is_available()
    if use_cuda:
        gpu_ids = args.gpu_ids if args.gpu_ids is not None \
            else range(torch.cuda.device_count())
        device_ids = [_ for _ in gpu_ids]
        data_parallel = len(device_ids) > 1
        device = 'cuda:%d' % device_ids[0]
        torch.cuda.set_device(device_ids[0])
        logging.info("Using CUDA devices: %s" % str(device_ids))
    else:
        data_parallel = False
        device = torch.device('cpu')
        logging.info("Using device: CPU")

    # Set multiprocessing params
    num_workers = min(multiprocessing.cpu_count(), args.n_workers)
    kwargs = {
        'num_workers': num_workers,
        'pin_memory': True
    } if use_cuda else {}

    # Set up data loader
    test_loader = torch.utils.data.DataLoader(data_test,
                                              batch_size=args.eval_batch_size,
                                              **kwargs)

    # Set up model
    model = network.Net(**args.model_params)
    if use_cuda and data_parallel:
        model = nn.DataParallel(model, device_ids=device_ids)
        logging.info("Using data parallel model")
    model.to(device)

    # Load weights
    if args.pretrain_path == "best":
        ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
        ckpts.sort(
            key=lambda _: int(os.path.splitext(os.path.basename(_))[0]))
        val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric]
        best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
        args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
        logging.info(
            "Found 'best' validation %s=%.02f at %s" %
            (args.base_metric, val_metrics[best_epoch], args.pretrain_path))
    if args.pretrain_path != "":
        utils.load_checkpoint(
            args.pretrain_path, model, data_parallel=data_parallel)
        logging.info("Loaded pretrain weights from %s" % args.pretrain_path)

    # Evaluate
    try:
        return test_epoch(
            model, device, test_loader, args.n_items, network.loss,
            network.metrics, args.profiling)
    except KeyboardInterrupt:
        print("Interrupted")
    except Exception as _:  # pylint: disable=broad-except
        import traceback  # pylint: disable=import-outside-toplevel
        traceback.print_exc()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Data Params
    parser.add_argument('experiments', nargs='+', type=str,
                        default=None,
                        help="List of experiments to evaluate. "
                             "Provide only one experiment when providing "
                             "pretrained path. If pretrained path is not "
                             "provided, epoch with best validation metric "
                             "is used for evaluation.")
    parser.add_argument('--results', type=str, default="",
                        help="Path to the CSV file to store results.")

    # System params
    parser.add_argument('--n_items', type=int, default=None,
                        help="Number of items to test.")
    parser.add_argument('--pretrain_path', type=str, default="best",
                        help="Path to pretrained weights")
    parser.add_argument('--profiling', dest='profiling', action='store_true',
                        help="Enable or disable profiling.")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    args = parser.parse_args()

    results = []

    for exp_dir in args.experiments:
        eval_args = argparse.Namespace(**vars(args))
        eval_args.exp_dir = exp_dir

        utils.set_logger(os.path.join(exp_dir, 'eval.log'))
        logging.info("Evaluating %s ..." % exp_dir)

        # Load model and training params
        params = utils.Params(os.path.join(exp_dir, 'config.json'))
        for k, v in params.__dict__.items():
            vars(eval_args)[k] = v

        network = importlib.import_module(eval_args.model)
        logging.info("Imported the model from '%s'." % eval_args.model)

        curr_res = evaluate(network, eval_args)
        curr_res['experiment'] = os.path.basename(exp_dir)
        results.append(curr_res)

        del eval_args

    if args.results != "":
        print("Writing results to %s" % args.results)
        pd.DataFrame(results).to_csv(args.results, index=False)
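Run standalone, the script expects each experiment directory to contain a config.json and numbered checkpoints named '%d.pt', e.g. `python -m src.training.eval experiments/waveformer --use_cuda` (the directory name here is illustrative). With the default --pretrain_path of "best", it picks the epoch with the highest validation value of the configured base_metric before testing.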
src/training/synthetic_dataset.py
ADDED
@@ -0,0 +1,168 @@
"""
Torch dataset object for synthetically rendered spatial data.
"""

import os
import json
import random
from pathlib import Path
import logging

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scaper
import torch
import torchaudio
import torchaudio.transforms as AT
from random import randrange

class FSDSoundScapesDataset(torch.utils.data.Dataset):  # type: ignore
    """
    Base class for FSD Sound Scapes dataset
    """

    _labels = [
        "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
        "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
        "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
        "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
        "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
        "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
        "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
        "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
        "Trumpet", "Violin_or_fiddle", "Writing"]

    def __init__(self, input_dir, dset='', sr=None,
                 resample_rate=None, max_num_targets=1):
        assert dset in ['train', 'val', 'test'], \
            "`dset` must be one of ['train', 'val', 'test']"
        self.dset = dset
        self.max_num_targets = max_num_targets
        self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
        if dset in ['train', 'val']:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-development')
        else:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-evaluation')
        logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
                     (dset, self.fg_dir, self.bg_dir))

        self.samples = sorted(list(
            Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))

        jamsfile = os.path.join(self.samples[0], 'mixture.jams')
        _, jams, _, _ = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
        _sr = jams['annotations'][0]['sandbox']['scaper']['sr']
        assert _sr == sr, "Sampling rate provided does not match the data"

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = resample_rate
        else:
            self.resampler = lambda a: a
            self.sr = sr

    def _get_label_vector(self, labels):
        """
        Generates a multi-hot vector corresponding to `labels`.
        """
        vector = torch.zeros(len(FSDSoundScapesDataset._labels))

        for label in labels:
            idx = FSDSoundScapesDataset._labels.index(label)
            assert vector[idx] == 0, "Repeated labels"
            vector[idx] = 1

        return vector

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        jamsfile = os.path.join(sample_path, 'mixture.jams')

        mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
        isolated_events = {}
        for e, a in zip(ann_list, event_audio_list[1:]):
            # 0th event is background
            isolated_events[e[2]] = a
        gt_events = list(pd.read_csv(
            os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])

        mixture = torch.from_numpy(mixture).permute(1, 0)
        mixture = self.resampler(mixture.to(torch.float))

        if self.dset == 'train':
            labels = random.sample(gt_events, randrange(1, self.max_num_targets + 1))
        elif self.dset == 'val':
            labels = gt_events[:idx % self.max_num_targets + 1]
        elif self.dset == 'test':
            labels = gt_events[:self.max_num_targets]
        label_vector = self._get_label_vector(labels)

        gt = torch.zeros_like(
            torch.from_numpy(event_audio_list[1]).permute(1, 0))
        for l in labels:
            gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
        gt = self.resampler(gt.to(torch.float))

        return mixture, label_vector, gt  # , jams

def tensorboard_add_sample(writer, tag, sample, step, params):
    """
    Adds a sample of FSDSoundScapesDataset to tensorboard.
    """
    if params['resample_rate'] is not None:
        sr = params['resample_rate']
    else:
        sr = params['sr']
    resample_rate = 16000 if sr > 16000 else sr

    m, l, gt, o = sample
    m, gt, o = (
        torchaudio.functional.resample(_, sr, resample_rate).cpu()
        for _ in (m, gt, o))

    def _add_audio(a, audio_tag, axis, plt_title):
        for i, ch in enumerate(a):
            axis.plot(ch, label='mic %d' % i)
            writer.add_audio(
                '%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate)
        axis.set_title(plt_title)
        axis.legend()

    for b in range(m.shape[0]):
        label = []
        for i in range(len(l[b, :])):
            if l[b, i] == 1:
                label.append(FSDSoundScapesDataset._labels[i])

        # Add waveforms
        rows = 3  # input, output, gt
        fig = plt.figure(figsize=(10, 2 * rows))
        axes = fig.subplots(rows, 1, sharex=True)
        _add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
        _add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1], "Output (%s)" % label)
        _add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2], "GT (%s)" % label)
        writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)

def tensorboard_add_metrics(writer, tag, metrics, label, step):
    """
    Add metrics to tensorboard.
    """
    vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])

    writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)

    label_names = [FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
    for l, v in zip(label_names, vals):
        writer.add_histogram('%s/%s' % (tag, l), v, step)
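As a sanity check, the dataset can be exercised directly once the scaper-rendered soundscapes are in place. A minimal sketch, assuming the data/FSDSoundScapes layout referenced in default_config.json:

import torch
from src.training.synthetic_dataset import FSDSoundScapesDataset

dset = FSDSoundScapesDataset(input_dir='data/FSDSoundScapes', dset='val',
                             sr=44100, max_num_targets=3)
mixture, label_vec, gt = dset[0]  # [channels, T] mixture, 41-dim multi-hot, matching target
loader = torch.utils.data.DataLoader(dset, batch_size=4, num_workers=2)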
src/training/train.py
ADDED
@@ -0,0 +1,311 @@
1 |
+
"""
|
2 |
+
The main training script for training on synthetic data
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import multiprocessing
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import random
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.optim as optim
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
from tqdm import tqdm # pylint: disable=unused-import
|
19 |
+
from torchmetrics.functional import(
|
20 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
21 |
+
signal_noise_ratio as snr,
|
22 |
+
signal_distortion_ratio as sdr,
|
23 |
+
scale_invariant_signal_distortion_ratio as si_sdr)
|
24 |
+
|
25 |
+
from src.helpers import utils
|
26 |
+
from src.training.eval import test_epoch
|
27 |
+
from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
|
28 |
+
from src.training.synthetic_dataset import tensorboard_add_sample
|
29 |
+
|
30 |
+
def train_epoch(model: nn.Module, device: torch.device,
|
31 |
+
optimizer: optim.Optimizer,
|
32 |
+
train_loader: torch.utils.data.dataloader.DataLoader,
|
33 |
+
n_items: int, epoch: int = 0,
|
34 |
+
writer: SummaryWriter = None, data_params = None) -> float:
|
35 |
+
|
36 |
+
"""
|
37 |
+
Train a single epoch.
|
38 |
+
"""
|
39 |
+
# Set the model to training.
|
40 |
+
model.train()
|
41 |
+
|
42 |
+
# Training loop
|
43 |
+
losses = []
|
44 |
+
metrics = {}
|
45 |
+
|
46 |
+
with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
|
47 |
+
for batch_idx, (mixed, label, gt) in enumerate(train_loader):
|
48 |
+
mixed = mixed.to(device)
|
49 |
+
label = label.to(device)
|
50 |
+
gt = gt.to(device)
|
51 |
+
|
52 |
+
# Reset grad
|
53 |
+
optimizer.zero_grad()
|
54 |
+
|
55 |
+
# Run through the model
|
56 |
+
output = model(mixed, label)
|
57 |
+
|
58 |
+
# Compute loss
|
59 |
+
loss = network.loss(output, gt)
|
60 |
+
|
61 |
+
losses.append(loss.item())
|
62 |
+
|
63 |
+
# Backpropagation
|
64 |
+
loss.backward()
|
65 |
+
|
66 |
+
# Gradient clipping
|
67 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
|
68 |
+
|
69 |
+
# Update the weights
|
70 |
+
optimizer.step()
|
71 |
+
|
72 |
+
metrics_batch = network.metrics(mixed.detach(), output.detach(),
|
73 |
+
gt.detach())
|
74 |
+
for k in metrics_batch.keys():
|
75 |
+
if not k in metrics:
|
76 |
+
metrics[k] = metrics_batch[k]
|
77 |
+
else:
|
78 |
+
metrics[k] += metrics_batch[k]
|
79 |
+
|
80 |
+
if writer is not None and batch_idx == 0:
|
81 |
+
tensorboard_add_sample(
|
82 |
+
writer, tag='Train',
|
83 |
+
sample=(mixed.detach()[:8], label.detach()[:8],
|
84 |
+
gt.detach()[:8], output.detach()[:8]),
|
85 |
+
step=epoch, params=data_params)
|
86 |
+
|
87 |
+
# Show current loss in the progress meter
|
88 |
+
t.set_postfix(loss='%.05f'%loss.item())
|
89 |
+
t.update()
|
90 |
+
|
91 |
+
if n_items is not None and batch_idx == n_items:
|
92 |
+
break
|
93 |
+
|
94 |
+
avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
|
95 |
+
avg_metrics['loss'] = np.mean(losses)
|
96 |
+
avg_metrics_str = "Train:"
|
97 |
+
for m in avg_metrics.keys():
|
98 |
+
avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
|
99 |
+
logging.info(avg_metrics_str)
|
100 |
+
|
101 |
+
return avg_metrics
|
102 |
+
|
103 |
+
|
104 |
+
def train(args: argparse.Namespace):
|
105 |
+
"""
|
106 |
+
Train the network.
|
107 |
+
"""
|
108 |
+
|
109 |
+
# Load dataset
|
110 |
+
data_train = Dataset(**args.train_data)
|
111 |
+
logging.info("Loaded train dataset at %s containing %d elements" %
|
112 |
+
(args.train_data['input_dir'], len(data_train)))
|
113 |
+
data_val = Dataset(**args.val_data)
|
114 |
+
logging.info("Loaded test dataset at %s containing %d elements" %
|
115 |
+
(args.val_data['input_dir'], len(data_val)))
|
116 |
+
|
117 |
+
# Set up the device and workers.
|
118 |
+
use_cuda = args.use_cuda and torch.cuda.is_available()
|
119 |
+
if use_cuda:
|
120 |
+
gpu_ids = args.gpu_ids if args.gpu_ids is not None\
|
121 |
+
else range(torch.cuda.device_count())
|
122 |
+
device_ids = [_ for _ in gpu_ids]
|
123 |
+
data_parallel = len(device_ids) > 1
|
124 |
+
device = 'cuda:%d' % device_ids[0]
|
125 |
+
torch.cuda.set_device(device_ids[0])
|
126 |
+
logging.info("Using CUDA devices: %s" % str(device_ids))
|
127 |
+
else:
|
128 |
+
data_parallel = False
|
129 |
+
device = torch.device('cpu')
|
130 |
+
logging.info("Using device: CPU")
|
131 |

    # Set multiprocessing params
    num_workers = min(multiprocessing.cpu_count(), args.n_workers)
    kwargs = {
        'num_workers': num_workers,
        'pin_memory': True
    } if use_cuda else {}

    # Set up data loaders
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=args.eval_batch_size,
                                             **kwargs)
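    # (pin_memory=True allocates batches in page-locked memory for faster
    # host-to-GPU copies; the validation loader is left unshuffled so the
    # per-epoch metrics are computed over the same ordering every time.)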

    # Set up model
    model = network.Net(**args.model_params)

    # Add graph to tensorboard with example train samples
    # _mixed, _label, _ = next(iter(val_loader))
    # args.writer.add_graph(model, (_mixed, _label))

    if use_cuda and data_parallel:
        model = nn.DataParallel(model, device_ids=device_ids)
        logging.info("Using data parallel model")
    model.to(device)
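    # (nn.DataParallel replicates the model on every id in device_ids, splits
    # each batch along dim 0, and gathers outputs on the first listed GPU.)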

    # Set up the optimizer
    logging.info("Initializing optimizer with %s" % str(args.optim))
    optimizer = network.optimizer(model, **args.optim,
                                  data_parallel=data_parallel)
    logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))

    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **args.lr_sched)
    logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
                 % (args.fix_lr_epochs, str(args.lr_sched)))
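    # (ReduceLROnPlateau lowers the learning rate when the monitored value
    # stops improving; since the best-checkpoint check below treats a larger
    # base_metric as better, args.lr_sched in config.json is presumably set
    # with mode='max'.)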

    base_metric = args.base_metric
    train_metrics = {}
    val_metrics = {}

    # Load the model if `args.start_epoch` is greater than 0. This will load
    # the model from epoch = `args.start_epoch - 1`.
    assert args.start_epoch >= 0, "start_epoch must be non-negative."
    if args.start_epoch > 0:
        checkpoint_path = os.path.join(args.exp_dir,
                                       '%d.pt' % (args.start_epoch - 1))
        _, train_metrics, val_metrics = utils.load_checkpoint(
            checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
            data_parallel=data_parallel)
        logging.info("Loaded checkpoint from %s" % checkpoint_path)
        logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer))

    # Training loop
    try:
        torch.autograd.set_detect_anomaly(args.detect_anomaly)
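        # (Anomaly detection records a traceback for every autograd op and
        # slows training considerably; enable it only to debug NaN losses.)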
        for epoch in range(args.start_epoch, args.epochs + 1):
            logging.info("Epoch %d:" % epoch)
            checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
            assert not os.path.exists(checkpoint_file), \
                "Checkpoint file %s already exists" % checkpoint_file
            curr_train_metrics = train_epoch(model, device, optimizer,
                                             train_loader, args.n_train_items,
                                             epoch=epoch, writer=args.writer,
                                             data_params=args.train_data)
            curr_test_metrics = test_epoch(model, device, val_loader,
                                           args.n_test_items, network.loss,
                                           network.metrics, epoch=epoch,
                                           writer=args.writer,
                                           data_params=args.val_data)

            # Step the LR scheduler on the validation metric, but only after
            # the initial fix_lr_epochs epochs with a fixed learning rate.
            if epoch >= args.fix_lr_epochs:
                lr_scheduler.step(curr_test_metrics[base_metric])
                logging.info(
                    "LR after scheduling step: %s" %
                    [_['lr'] for _ in optimizer.param_groups])

            # Write metrics to tensorboard
            args.writer.add_scalars('Train', curr_train_metrics, epoch)
            args.writer.add_scalars('Val', curr_test_metrics, epoch)
            args.writer.flush()

            for k in curr_train_metrics.keys():
                if k not in train_metrics:
                    train_metrics[k] = [curr_train_metrics[k]]
                else:
                    train_metrics[k].append(curr_train_metrics[k])

            for k in curr_test_metrics.keys():
                if k not in val_metrics:
                    val_metrics[k] = [curr_test_metrics[k]]
                else:
                    val_metrics[k].append(curr_test_metrics[k])

            if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
                logging.info("Found best validation %s!" % base_metric)

            utils.save_checkpoint(
                checkpoint_file, epoch, model, optimizer, lr_scheduler,
                train_metrics, val_metrics, data_parallel)
            logging.info("Saved checkpoint at %s" % checkpoint_file)

            utils.save_graph(train_metrics, val_metrics, args.exp_dir)
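            # (A checkpoint is written every epoch; the best epoch is only
            # logged here, so the best model must be picked from the saved
            # metric history afterwards.)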

        return train_metrics, val_metrics

    except KeyboardInterrupt:
        print("Interrupted")
    except Exception:  # pylint: disable=broad-except
        import traceback  # pylint: disable=import-outside-toplevel
        traceback.print_exc()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Data Params
    parser.add_argument('exp_dir', type=str,
                        default='./experiments/fsd_mask_label_mult',
                        help="Path to save checkpoints and logs.")

    parser.add_argument('--n_train_items', type=int, default=None,
                        help="Number of items to train on in each epoch")
    parser.add_argument('--n_test_items', type=int, default=None,
                        help="Number of items to test.")
    parser.add_argument('--start_epoch', type=int, default=0,
                        help="Start epoch")
    parser.add_argument('--pretrain_path', type=str,
                        help="Path to pretrained weights")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "E.g., --gpu_ids 2 4. All GPUs are used by default.")
    parser.add_argument('--detect_anomaly', dest='detect_anomaly',
                        action='store_true',
                        help="Whether to enable autograd anomaly detection")
    parser.add_argument('--wandb', dest='wandb', action='store_true',
                        help="Whether to sync tensorboard to wandb")

    args = parser.parse_args()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    random.seed(230)
    np.random.seed(230)
    if args.use_cuda:
        torch.cuda.manual_seed(230)

    # Set up checkpoints
    if not os.path.exists(args.exp_dir):
        os.makedirs(args.exp_dir)

    utils.set_logger(os.path.join(args.exp_dir, 'train.log'))

    # Load model and training params
    params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v

    # Initialize tensorboard writer
    tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
    args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
    if args.wandb:
        import wandb
        wandb.init(
            project='Semaudio', sync_tensorboard=True,
            dir=tensorboard_dir, name=os.path.basename(args.exp_dir))

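    # (The exec-based import binds the module named in config.json as the
    # global `network` used throughout this file;
    # importlib.import_module(args.model) would be the more conventional
    # equivalent.)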
    exec("import %s as network" % args.model)
    logging.info("Imported the model from '%s'." % args.model)

    train(args)

    args.writer.close()
    if args.wandb:
        wandb.finish()
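
# Example invocation (illustrative; assumes a config.json has been placed in
# the experiment directory, e.g. by copying default_config.json there):
#   python -m src.training.train ./experiments/fsd_mask_label_mult \
#       --use_cuda --gpu_ids 0 1 --wandb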