import os
import argparse
import json

from tqdm import tqdm
from copy import deepcopy

import numpy as np
import torch
import random
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from scipy.io.wavfile import write as wavwrite

from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet


def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise audio

    Parameters:
    output_directory (str):        save generated speeches to this path
    ckpt_iter (int or 'max'):      the pretrained checkpoint to be loaded;
                                   automatically selects the maximum iteration if 'max' is selected
    subset (str):                  training, testing, validation
    dump (bool):                   whether to save the enhanced (denoised) audio
    """

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data (crop_length_sec = 0 disables cropping so full utterances are denoised)
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1
    )

    # predefine model
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # get output directory ready
    if ckpt_iter == "pretrained":
        speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
    else:
        speech_directory = os.path.join(output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter // 1000))
    if dump and not os.path.isdir(speech_directory):
        os.makedirs(speech_directory)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    # inference
    all_generated_audio = []
    all_clean_audio = []
    sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:])
    for clean_audio, noisy_audio, fileid in tqdm(dataloader):
        filename = sortkey(fileid[0][0])

        noisy_audio = noisy_audio.cuda()
        LENGTH = len(noisy_audio[0].squeeze())
        generated_audio = sampling(net, noisy_audio)

        if dump:
            wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                     trainset_config["sample_rate"],
                     generated_audio[0].squeeze().cpu().numpy())
        else:
            all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
            all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())

    return all_clean_audio, all_generated_audio


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max" or "pretrained"')
    parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
                        default='testing', help='subset for denoising')
    args = parser.parse_args()
    # Parse configs. Globals are nicer in this case
    with open(args.config) as f:
        data = f.read()
    config = json.loads(data)
    gen_config = config["gen_config"]
    global network_config
    network_config = config["network_config"]    # to define the CleanUNet model
    global train_config
    train_config = config["train_config"]        # training configuration
    global trainset_config
    trainset_config = config["trainset_config"]  # to read trainset configurations

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    if args.subset == "testing":
        denoise(gen_config["output_directory"],
                subset=args.subset,
                ckpt_iter=args.ckpt_iter,
                dump=True)
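    # Usage sketch (the script name and config paths below are hypothetical;
    # only the top-level keys -- "gen_config", "network_config", "train_config",
    # "trainset_config" -- are required by the code above):
    #
    #   python denoise.py -c config.json --ckpt_iter max --subset testing
    #
    # A minimal config.json might look like this (values are illustrative,
    # not the repository's defaults):
    #
    #   {
    #       "gen_config":      {"output_directory": "./exp"},
    #       "network_config":  {...},
    #       "train_config":    {"exp_path": "my_exp",
    #                           "log": {"directory": "./logs"}},
    #       "trainset_config": {"crop_length_sec": 10,
    #                           "sample_rate": 16000}
    #   }
    #
    # Enhanced files are then written to
    # <output_directory>/<exp_path>/speech/<ckpt_iter//1000>k/enhanced_<filename>.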