""" | |
Torch dataset object for synthetically rendered spatial data. | |
""" | |
import os
import json
import random
import logging
from pathlib import Path
from random import randrange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scaper
import torch
import torchaudio
import torchaudio.transforms as AT
class FSDSoundScapesDataset(torch.utils.data.Dataset):  # type: ignore
    """
    Base class for the FSD Sound Scapes dataset.
    """

    _labels = [
        "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
        "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
        "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
        "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
        "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
        "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
        "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
        "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
        "Trumpet", "Violin_or_fiddle", "Writing"]
    def __init__(self, input_dir, dset='', sr=None,
                 resample_rate=None, max_num_targets=1):
        assert dset in ['train', 'val', 'test'], \
            "`dset` must be one of ['train', 'val', 'test']"
        self.dset = dset
        self.max_num_targets = max_num_targets

        # Foreground events come from FSDKaggle2018; backgrounds come from
        # the TAU Urban Acoustic Scenes 2019 development/evaluation splits.
        self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
        if dset in ['train', 'val']:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-development')
        else:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-evaluation')
        logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
                     (dset, self.fg_dir, self.bg_dir))

        self.samples = sorted(list(
            Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))

        # Render one mixture up front to verify that the sampling rate in
        # the jams metadata matches the rate the caller expects.
        jamsfile = os.path.join(self.samples[0], 'mixture.jams')
        _, jams, _, _ = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
        _sr = jams['annotations'][0]['sandbox']['scaper']['sr']
        assert _sr == sr, "Sampling rate provided does not match the data"

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = resample_rate
        else:
            self.resampler = lambda a: a
            self.sr = sr
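    # Expected layout under `input_dir`, inferred from the paths used in
    # `__init__` and `__getitem__` (a sketch, not an authoritative spec):
    #
    #   FSDKaggle2018/{train,val,test}/           foreground event clips
    #   TAU-acoustic-sounds/
    #       TAU-urban-acoustic-scenes-2019-{development,evaluation}/
    #                                             background scenes
    #   jams/{train,val,test}/<id>/mixture.jams   scaper mixture specs
    #   jams/{train,val,test}/<id>/gt_events.csv  tab-separated event labels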
    def _get_label_vector(self, labels):
        """
        Generates a multi-hot vector corresponding to `labels`.
        """
        vector = torch.zeros(len(FSDSoundScapesDataset._labels))
        for label in labels:
            idx = FSDSoundScapesDataset._labels.index(label)
            assert vector[idx] == 0, "Repeated labels"
            vector[idx] = 1
        return vector

    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        jamsfile = os.path.join(sample_path, 'mixture.jams')

        mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)

        # The 0th entry of event_audio_list is the background; map each
        # foreground annotation's label to its isolated audio.
        isolated_events = {}
        for e, a in zip(ann_list, event_audio_list[1:]):
            isolated_events[e[2]] = a

        gt_events = list(pd.read_csv(
            os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])

        mixture = torch.from_numpy(mixture).permute(1, 0)
        mixture = self.resampler(mixture.to(torch.float))

        if self.dset == 'train':
            # Sample between 1 and `max_num_targets` target labels at random.
            labels = random.sample(
                gt_events, randrange(1, self.max_num_targets + 1))
        elif self.dset == 'val':
            # Cycle the number of targets deterministically with the index.
            labels = gt_events[:idx % self.max_num_targets + 1]
        elif self.dset == 'test':
            labels = gt_events[:self.max_num_targets]
        label_vector = self._get_label_vector(labels)

        # The ground truth is the sum of the isolated events for the
        # selected labels.
        gt = torch.zeros_like(
            torch.from_numpy(event_audio_list[1]).permute(1, 0))
        for l in labels:
            gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
        gt = self.resampler(gt.to(torch.float))

        return mixture, label_vector, gt  # , jams
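# Minimal usage sketch (hypothetical paths and rates; assumes the mixtures
# were rendered at 44.1 kHz, as in FSDKaggle2018):
#
#   train_ds = FSDSoundScapesDataset(
#       '/path/to/data', dset='train', sr=44100,
#       resample_rate=8000, max_num_targets=3)
#   loader = torch.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True)
#   mixture, label_vector, gt = next(iter(loader))
#   # mixture/gt: (batch, channels, time); label_vector: (batch, 41)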
def tensorboard_add_sample(writer, tag, sample, step, params):
    """
    Adds a sample of FSDSoundScapesDataset to tensorboard.
    """
    if params['resample_rate'] is not None:
        sr = params['resample_rate']
    else:
        sr = params['sr']
    # Log audio at no more than 16 kHz.
    resample_rate = 16000 if sr > 16000 else sr

    m, l, gt, o = sample
    m, gt, o = (
        torchaudio.functional.resample(_, sr, resample_rate).cpu()
        for _ in (m, gt, o))

    def _add_audio(a, audio_tag, axis, plt_title):
        for i, ch in enumerate(a):
            axis.plot(ch, label='mic %d' % i)
            writer.add_audio(
                '%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step,
                resample_rate)
        axis.set_title(plt_title)
        axis.legend()

    for b in range(m.shape[0]):
        # Recover the label names from the multi-hot vector.
        label = []
        for i in range(len(l[b, :])):
            if l[b, i] == 1:
                label.append(FSDSoundScapesDataset._labels[i])

        # Add waveforms: input, output, gt
        rows = 3
        fig = plt.figure(figsize=(10, 2 * rows))
        axes = fig.subplots(rows, 1, sharex=True)
        _add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
        _add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1],
                   "Output (%s)" % label)
        _add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2],
                   "GT (%s)" % label)
        writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)
def tensorboard_add_metrics(writer, tag, metrics, label, step):
    """
    Add metrics to tensorboard.
    """
    vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])
    writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)

    # Per-label histograms: each sample is keyed by the argmax of its
    # multi-hot label vector (i.e., its first active label).
    label_names = [
        FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
    for l, v in zip(label_names, vals):
        writer.add_histogram('%s/%s' % (tag, l), v, step)
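# Hypothetical logging call, assuming a torch.utils.tensorboard SummaryWriter
# and a separator output `out` shaped like `gt`:
#
#   writer = torch.utils.tensorboard.SummaryWriter('runs/fsd')
#   tensorboard_add_sample(
#       writer, 'val', (mixture, label_vector, gt, out), step=0,
#       params={'sr': 44100, 'resample_rate': 8000})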