|
import os
|
|
import argparse
|
|
import json
|
|
from tqdm import tqdm
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import random
|
|
# Fix all three RNG sources at import time so sampling/inference is reproducible.
random.seed(0)

torch.manual_seed(0)

np.random.seed(0)
|
|
|
|
from scipy.io.wavfile import write as wavwrite
|
|
|
|
from dataset import load_CleanNoisyPairDataset
|
|
from util import find_max_epoch, print_size, sampling
|
|
from network import CleanUNet
|
|
|
|
|
|
def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise audio with a trained CleanUNet model.

    Parameters:
    output_directory (str):     save generated speeches to this path
    ckpt_iter (int or 'max'):   the pretrained checkpoint to be loaded;
                                automatically selects the maximum iteration
                                if 'max' is selected
    subset (str):               training, testing, validation
    dump (bool):                whether save enhanced (denoised) audio

    Returns:
    (all_clean_audio, all_generated_audio): lists of numpy arrays; both lists
    are empty when dump=True (audio is written to disk instead).

    NOTE: reads module-level globals train_config, trainset_config, and
    network_config, which are populated from the JSON config in __main__.
    """

    # local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data; crop_length_sec=0 keeps full-length utterances for evaluation
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1
    )

    # build model
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # resolve and load checkpoint
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # output directory for enhanced audio
    if ckpt_iter == "pretrained":
        speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
    else:
        speech_directory = os.path.join(output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter//1000))
    if dump and not os.path.isdir(speech_directory):
        os.makedirs(speech_directory)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    # inference
    all_generated_audio = []
    all_clean_audio = []
    # drop the leading token of the basename (e.g. "noisy_") so enhanced files
    # share a key with their clean counterparts — presumably fileids look like
    # "<prefix>_<id>.wav"; TODO confirm against dataset.py
    sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:])
    for clean_audio, noisy_audio, fileid in tqdm(dataloader):
        filename = sortkey(fileid[0][0])

        noisy_audio = noisy_audio.cuda()
        # fix: run inference under no_grad — the original built an autograd
        # graph for every utterance, wasting GPU memory for no benefit
        with torch.no_grad():
            generated_audio = sampling(net, noisy_audio)

        if dump:
            wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                     trainset_config["sample_rate"],
                     generated_audio[0].squeeze().cpu().numpy())
        else:
            all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
            all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())

    return all_clean_audio, all_generated_audio
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('-c', '--config', type=str, default='config.json',
|
|
help='JSON file for configuration')
|
|
parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
|
|
help='Which checkpoint to use; assign a number or "max" or "pretrained"')
|
|
parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
|
|
default='testing', help='subset for denoising')
|
|
args = parser.parse_args()
|
|
|
|
|
|
with open(args.config) as f:
|
|
data = f.read()
|
|
config = json.loads(data)
|
|
gen_config = config["gen_config"]
|
|
global network_config
|
|
network_config = config["network_config"]
|
|
global train_config
|
|
train_config = config["train_config"]
|
|
global trainset_config
|
|
trainset_config = config["trainset_config"]
|
|
|
|
torch.backends.cudnn.enabled = True
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
if args.subset == "testing":
|
|
denoise(gen_config["output_directory"],
|
|
subset=args.subset,
|
|
ckpt_iter=args.ckpt_iter,
|
|
dump=True) |