#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
from toolbox.torchaudio.models.mpnet.metrics import run_batch_pesq, run_pesq_score


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
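    # NOTE: the CollateFunction above assumes each sample yielded by
    # DenoiseExcelDataset is a dict containing at least a "speech_wave" (clean)
    # and a "mix_wave" (noisy) waveform tensor of identical length; "noise_wave"
    # and "snr_db" may also be present but are unused here.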
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = MPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminatorPretrainedModel(config).to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler")

    num_params = 0
    for p in generator.parameters():
        num_params += p.numel()
    logger.info("total parameters (generator): {:.3f}M".format(num_params / 1e6))

    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    # resume training
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_i = Path(epoch_i)
        epoch_idx = epoch_i.stem.split("-")[1]
        epoch_idx = int(epoch_idx)
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
        discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
        optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
        optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"

        logger.info(f"load state dict for generator.")
        with open(generator_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            generator.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for discriminator.")
        with open(discriminator_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            discriminator.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optim_g.")
        with open(optim_g_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            optim_g.load_state_dict(state_dict)

        logger.info(f"load state dict for optim_d.")
        with open(optim_d_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            optim_d.load_state_dict(state_dict)

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)

    # training loop

    # state
    loss_d = 10000000000
    loss_g = 10000000000
    pesq_metric = 10000000000
    mag_err = 10000000000
    pha_err = 10000000000
    com_err = 10000000000
    stft_err = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        # train
        generator.train()
        discriminator.train()

        total_loss_d = 0.
        total_loss_g = 0.
        total_batches = 0.
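        # Each training step below follows the MetricGAN-style recipe used by
        # MP-SENet: the discriminator is trained to predict (normalized) PESQ
        # scores of the enhanced magnitudes, while the generator is trained with
        # a weighted sum of magnitude, phase, complex, STFT-consistency, time
        # and metric losses.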
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audio, noisy_audio = batch
            clean_audio = clean_audio.to(device)
            noisy_audio = noisy_audio.to(device)
            one_labels = torch.ones(clean_audio.shape[0]).to(device)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

            audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
            pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator.forward(clean_mag, clean_mag)
            metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if -1 in pesq_score_list:
                # print("-1 in batch_pesq_score!")
                loss_disc_g = 0
            else:
                pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
                loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()
            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            metric_g = discriminator.forward(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            total_loss_d += loss_disc_all.item()
            total_loss_g += loss_gen_all.item()
            total_batches += 1

            loss_d = round(total_loss_d / total_batches, 4)
            loss_g = round(total_loss_g / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "loss_d": loss_d,
                "loss_g": loss_g,
            })

        # evaluation
        generator.eval()
        discriminator.eval()

        torch.cuda.empty_cache()
        total_pesq_score = 0.
        total_mag_err = 0.
        total_pha_err = 0.
        total_com_err = 0.
        total_stft_err = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audio, noisy_audio = batch
                clean_audio = clean_audio.to(device)
                noisy_audio = noisy_audio.to(device)

                clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

                audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                clean_audio_list = torch.split(clean_audio, 1, dim=0)
                enhanced_audio_list = torch.split(audio_g, 1, dim=0)
                clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
                enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
                pesq_score = run_pesq_score(
                    clean_audio_list,
                    enhanced_audio_list,
                    sample_rate=config.sample_rate,
                    mode="nb",
                )
                total_pesq_score += pesq_score
                total_mag_err += F.mse_loss(clean_mag, mag_g).item()
                val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
                total_com_err += F.mse_loss(clean_com, com_g).item()
                total_stft_err += F.mse_loss(com_g, com_g_hat).item()
                total_batches += 1

                pesq_metric = round(total_pesq_score / total_batches, 4)
                mag_err = round(total_mag_err / total_batches, 4)
                pha_err = round(total_pha_err / total_batches, 4)
                com_err = round(total_com_err / total_batches, 4)
                stft_err = round(total_stft_err / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_metric": pesq_metric,
                    "mag_err": mag_err,
                    "pha_err": pha_err,
                    "com_err": com_err,
                    "stft_err": stft_err,
                })

        # scheduler
        scheduler_g.step()
        scheduler_d.step()

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        generator.save_pretrained(epoch_dir.as_posix())
        discriminator.save_pretrained(epoch_dir.as_posix())

        # save optim
        torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
        torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())

        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = pesq_metric
        elif pesq_metric > best_metric:
            # greater is better.
            best_idx_epoch = idx_epoch
            best_metric = pesq_metric
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "loss_d": loss_d,
            "loss_g": loss_g,

            "pesq_metric": pesq_metric,
            "mag_err": mag_err,
            "pha_err": pha_err,
            "com_err": com_err,
            "stft_err": stft_err,
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break
    return


if __name__ == "__main__":
    main()
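# Example invocation (a sketch; the script filename and dataset/config paths
# are placeholders, argument names come from get_args above):
#
#   python3 train.py \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --serialization_dir serialization_dir \
#       --config_file config.yaml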