# denoiser/denoise.py
# Commit 33e3a91 ("Init"), committed by azamat; file size 4.58 kB.
import os
import argparse
import json
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import torch
import random
# Seed the global RNGs of Python's `random`, PyTorch and NumPy with 0, in that
# order, so random draws made by this script start from a fixed state.
# (GPU results can still vary between runs, e.g. when cuDNN benchmark mode is on.)
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
from scipy.io.wavfile import write as wavwrite
from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet
def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise audio with a trained CleanUNet checkpoint.

    Reads the module-level ``train_config``, ``trainset_config`` and
    ``network_config`` dicts, which are populated from the JSON config file
    by this script's ``__main__`` entry point; calling this function before
    they exist raises NameError.

    Parameters:
    output_directory (str):     save generated speeches under this path
    ckpt_iter (int, str, 'max' or 'pretrained'):
                                checkpoint to load from
                                <log directory>/<exp_path>/checkpoint/.
                                'max' automatically selects the maximum
                                iteration (via find_max_epoch), 'pretrained'
                                loads 'pretrained.pkl', anything else is
                                converted with int() and loads '<N>.pkl'
    subset (str):               training, testing, validation
    dump (bool):                whether to save enhanced (denoised) audio as
                                .wav files under
                                <output_directory>/<exp_path>/speech/<ckpt>/

    Returns:
    (all_clean_audio, all_generated_audio): two lists of numpy arrays, one
    entry per utterance. Both lists stay empty when dump=True, because the
    enhanced audio is written to disk instead of being collected in memory.
    """
    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data; work on a copy so the shared trainset_config is not mutated.
    # crop_length_sec=0 presumably disables cropping so whole utterances are
    # denoised — confirm in dataset.py.
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1,
    )

    # predefine model
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # get output directory ready: 'pretrained' -> .../speech/pretrained,
    # iteration N -> .../speech/<N // 1000>k
    if ckpt_iter == 'pretrained':
        speech_subdir = ckpt_iter
    else:
        speech_subdir = '{}k'.format(ckpt_iter // 1000)
    speech_directory = os.path.join(output_directory, exp_path, 'speech', speech_subdir)
    if dump and not os.path.isdir(speech_directory):
        os.makedirs(speech_directory)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    def output_name(path):
        # Keep only the base name and drop its first '_'-separated token,
        # e.g. 'dir/noisy_fileid_7.wav' -> 'fileid_7.wav'.
        return '_'.join(path.split('/')[-1].split('_')[1:])

    # inference
    all_generated_audio = []
    all_clean_audio = []
    # Tensor.numpy() raises on tensors that require grad, so run inference with
    # autograd disabled; this also avoids keeping autograd state in memory.
    with torch.no_grad():
        for clean_audio, noisy_audio, fileid in tqdm(dataloader):
            generated_audio = sampling(net, noisy_audio.cuda())
            enhanced = generated_audio[0].squeeze().cpu().numpy()
            if dump:
                # fileid[0][0] is presumably the source path of the single
                # utterance in this batch (batch_size=1) — confirm in dataset.py
                filename = output_name(fileid[0][0])
                wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                         trainset_config["sample_rate"],
                         enhanced)
            else:
                all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
                all_generated_audio.append(enhanced)
    return all_clean_audio, all_generated_audio
def main():
    """Command-line entry point: parse arguments and the JSON config, then denoise.

    The ``network_config``, ``train_config`` and ``trainset_config`` sections
    are published as module globals because ``denoise`` reads them from there.
    Only ``--subset testing`` triggers denoising; other subsets print a notice
    to stderr and return.
    """
    import sys  # for the stderr notice below

    # Globals nicer in this case: denoise() reads these at module scope.
    global network_config, train_config, trainset_config

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max" or "pretrained"')
    parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
                        default='testing', help='subset for denoising')
    args = parser.parse_args()

    # Parse configs
    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)
    gen_config = config["gen_config"]
    network_config = config["network_config"]    # CleanUNet constructor kwargs
    train_config = config["train_config"]        # train config
    trainset_config = config["trainset_config"]  # to read trainset configurations

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    if args.subset != "testing":
        # Previously this case exited silently; make the no-op explicit.
        print('Nothing to do: denoising is only run for --subset testing '
              '(got {!r}).'.format(args.subset), file=sys.stderr)
        return

    denoise(gen_config["output_directory"],
            subset=args.subset,
            ckpt_iter=args.ckpt_iter,
            dump=True)


if __name__ == "__main__":
    main()