denoiser / dataset.py
azamat's picture
Init
33e3a91
raw
history blame contribute delete
No virus
4.64 kB
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import numpy as np
from scipy.io.wavfile import read as wavread
import warnings
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
from torchvision import datasets, models, transforms
import torchaudio
class CleanNoisyPairDataset(Dataset):
    """
    Dataset of paired clean and noisy audio files.

    Each element is a tuple ``(clean waveform, noisy waveform, fileid)`` where
    ``fileid`` is the ``(clean_path, noisy_path)`` pair the waveforms were
    loaded from.
    """

    def __init__(self, root='./', subset='training', crop_length_sec=0):
        """
        Collect the (clean, noisy) file path pairs for ``subset``.

        Args:
            root: base directory. Training files are read from
                ``<root>/training_set/{clean,noisy}/fileid_<i>.wav``; testing
                files from
                ``<root>/datasets/test_set/synthetic/no_reverb/{clean,noisy}``.
            subset: ``"training"`` or ``"testing"``.
            crop_length_sec: length in seconds of the random crop applied to
                training items; 0 disables cropping. Always reset to 0 for the
                testing subset.

        Raises:
            AssertionError: if ``subset`` is not an accepted value, or the
                clean and noisy folders do not pair up.
            NotImplementedError: if ``subset`` is ``None``.
        """
        super().__init__()

        assert subset is None or subset in ["training", "testing"], \
            "subset must be 'training' or 'testing'"
        self.crop_length_sec = crop_length_sec
        self.subset = subset

        if subset == "training":
            clean_dir = os.path.join(root, 'training_set/clean')
            noisy_dir = os.path.join(root, 'training_set/noisy')

            # Only the training subset needs the training folders; listing
            # them unconditionally made test-only checkouts fail.
            n_clean = len(os.listdir(clean_dir))
            n_noisy = len(os.listdir(noisy_dir))
            assert n_clean == n_noisy, \
                "clean and noisy training folders hold different numbers of files"

            # File names are assumed to follow fileid_<i>.wav for i in [0, N).
            self.files = [
                (os.path.join(clean_dir, 'fileid_{}.wav'.format(i)),
                 os.path.join(noisy_dir, 'fileid_{}.wav'.format(i)))
                for i in range(n_clean)
            ]

        elif subset == "testing":
            def sortkey(name):
                # Specific to DNS test sample names: the last two
                # '_'-separated fields identify the sample.
                return '_'.join(name.split('_')[-2:])

            _p = os.path.join(root, 'datasets/test_set/synthetic/no_reverb')  # path for DNS
            clean_files = sorted(os.listdir(os.path.join(_p, 'clean')), key=sortkey)
            noisy_files = sorted(os.listdir(os.path.join(_p, 'noisy')), key=sortkey)

            # zip() would silently drop the unmatched tail otherwise.
            assert len(clean_files) == len(noisy_files), \
                "clean and noisy test folders hold different numbers of files"

            self.files = []
            for _c, _n in zip(clean_files, noisy_files):
                assert sortkey(_c) == sortkey(_n), \
                    "mismatched test pair: {} vs {}".format(_c, _n)
                self.files.append((os.path.join(_p, 'clean', _c),
                                   os.path.join(_p, 'noisy', _n)))

            # Test items are never cropped.
            self.crop_length_sec = 0

        else:
            raise NotImplementedError

    def __getitem__(self, n):
        """
        Load the ``n``-th pair and, for non-testing subsets, randomly crop it.

        Returns:
            ``(clean_audio, noisy_audio, fileid)``; both waveforms carry a
            leading dimension of size 1 added by ``unsqueeze(0)``.

        Raises:
            AssertionError: if the two waveforms differ in length or the
                requested crop is longer than the clip.
        """
        fileid = self.files[n]
        clean_audio, _ = torchaudio.load(fileid[0])
        # As before, the sample rate of the noisy file is the one used.
        noisy_audio, sample_rate = torchaudio.load(fileid[1])

        # NOTE(review): squeeze(0) only drops the channel axis for
        # single-channel files - confirm the data is mono.
        clean_audio, noisy_audio = clean_audio.squeeze(0), noisy_audio.squeeze(0)
        assert len(clean_audio) == len(noisy_audio), \
            "clean and noisy waveforms differ in length"

        crop_length = int(self.crop_length_sec * sample_rate)
        # A crop exactly as long as the clip is valid (start can only be 0).
        assert crop_length <= len(clean_audio), \
            "crop length exceeds the length of the audio clip"

        # random crop
        if self.subset != 'testing' and crop_length > 0:
            start = np.random.randint(low=0, high=len(clean_audio) - crop_length + 1)
            clean_audio = clean_audio[start:(start + crop_length)]
            noisy_audio = noisy_audio[start:(start + crop_length)]

        clean_audio, noisy_audio = clean_audio.unsqueeze(0), noisy_audio.unsqueeze(0)
        return (clean_audio, noisy_audio, fileid)

    def __len__(self):
        """Return the number of (clean, noisy) file pairs."""
        return len(self.files)
def load_CleanNoisyPairDataset(root, subset, crop_length_sec, batch_size, sample_rate, num_gpus=1):
    """
    Build a DataLoader over a CleanNoisyPairDataset.

    With more than one GPU the loader draws indices from a
    DistributedSampler; otherwise it shuffles the dataset itself.

    Args:
        root, subset, crop_length_sec: forwarded to CleanNoisyPairDataset.
        batch_size: number of items per batch.
        sample_rate: accepted so callers can pass it, but not used here.
        num_gpus: number of GPUs taking part in the run.

    Returns:
        The configured torch.utils.data.DataLoader.
    """
    loader_options = {
        "batch_size": batch_size,
        "num_workers": 4,
        "pin_memory": False,
        "drop_last": False,
    }
    pairs = CleanNoisyPairDataset(root=root, subset=subset, crop_length_sec=crop_length_sec)

    if num_gpus > 1:
        # Multi-GPU: the sampler decides which indices this loader yields.
        return torch.utils.data.DataLoader(
            pairs, sampler=DistributedSampler(pairs), **loader_options)

    return torch.utils.data.DataLoader(pairs, sampler=None, shuffle=True, **loader_options)
if __name__ == '__main__':
    # Manual smoke test: build both loaders from the DNS config file and
    # print the shapes of the first training batch.
    import json

    with open('./configs/DNS-large-full.json') as config_file:
        trainset_config = json.load(config_file)["trainset_config"]

    train_loader = load_CleanNoisyPairDataset(
        **trainset_config, subset='training', batch_size=2, num_gpus=1)
    test_loader = load_CleanNoisyPairDataset(
        **trainset_config, subset='testing', batch_size=2, num_gpus=1)
    print(len(train_loader), len(test_loader))

    # Only the first batch is inspected.
    for clean_batch, noisy_batch, batch_ids in train_loader:
        clean_batch, noisy_batch = clean_batch.cuda(), noisy_batch.cuda()
        print(clean_batch.shape, noisy_batch.shape, batch_ids)
        break