# maskgct-audio-lab / models /tts /jets /jets_trainer.py
# Hecheng0625's picture
# Upload 409 files
# c968fc3 verified
# raw
# history blame
# 12.2 kB
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import os
from utils.io import save_audio
from tqdm import tqdm
from models.tts.base import TTSTrainer
from models.tts.jets.jets import Jets
from models.tts.jets.jets_loss import GeneratorLoss, DiscriminatorLoss
from models.tts.jets.jets_dataset import JetsDataset, JetsCollator
from optimizer.optimizers import NoamLR
from torch.optim.lr_scheduler import ExponentialLR
from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Cut one fixed-length window out of every batch item.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Per-item window start along the last axis (B,).
        segment_size (int): Length of each window.

    Returns:
        Tensor: Windows stacked as (B, C, segment_size), created with
        ``x.new_zeros`` so dtype and device follow ``x``.
    """
    batch_size, channels, _ = x.size()
    # Pre-allocate with zeros: rows without a matching start index stay zero.
    out = x.new_zeros(batch_size, channels, segment_size)
    for row, begin in enumerate(start_idxs):
        end = begin + segment_size
        out[row] = x[row, :, begin:end]
    return out
class JetsTrainer(TTSTrainer):
    """Adversarial trainer for the Jets TTS model.

    ``self.model`` is a dict holding a ``Jets`` generator and a
    ``MultiScaleMultiPeriodDiscriminator``. ``self.optimizer`` and
    ``self.scheduler`` are dicts with one AdamW optimizer / one
    ExponentialLR scheduler per network (keys ``*_g`` and ``*_d``).
    """

    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)
        self.cfg = cfg

    def _build_dataset(self):
        """Return the dataset class and collator class (not instances)."""
        return JetsDataset, JetsCollator

    def __build_scheduler(self):
        # NOTE(review): the double underscore name-mangles this method to
        # ``_JetsTrainer__build_scheduler`` and nothing in this class calls
        # it; ``_build_scheduler`` below is the one that matches the
        # dict-shaped ``self.optimizer``. Kept for backward compatibility.
        return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)

    def _write_summary(
        self,
        losses,
        stats,
        images=None,
        audios=None,
        audio_sampling_rate=24000,
        tag="train",
    ):
        """Log training losses, the generator learning rate, and media.

        Args:
            losses (dict): Scalar name -> value, logged as ``<tag>/<name>``.
            stats: Unused here; kept for interface compatibility.
            images (dict | None): Image name -> image, logged when non-empty.
            audios (dict | None): Audio name -> audio, logged when non-empty.
            audio_sampling_rate (int): Sample rate passed to ``add_audio``.
            tag (str): Prefix for the scalar names.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )

        # NOTE(review): media uses ``self.global_step`` while scalars use
        # ``self.step``; ``global_step`` is not set in this class - confirm
        # that the base trainer defines it.
        # ``dataformats`` is the tensorboard SummaryWriter keyword (assumes
        # ``self.sw`` follows that API); the former ``batchformats`` is not
        # an accepted argument.
        if images:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, dataformats="HWC")
        if audios:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
        # A leftover second body used to follow here. It re-logged the losses
        # under a hard-coded "train/" prefix and called
        # ``self.optimizer.state_dict()``, which fails because
        # ``self.optimizer`` is a dict of optimizers. Both values are already
        # logged above, so it was removed.

    def _write_valid_summary(
        self, losses, stats, images=None, audios=None, audio_sampling_rate=24000, tag="val"
    ):
        """Log validation losses and optional media.

        Args:
            losses (dict): Scalar name -> value, logged as ``<tag>/<name>``.
            stats: Unused here; kept for interface compatibility.
            images (dict | None): Image name -> image, logged when non-empty.
            audios (dict | None): Audio name -> audio, logged when non-empty.
            audio_sampling_rate (int): Sample rate passed to ``add_audio``.
            tag (str): Prefix for the scalar names.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        # See ``dataformats`` / ``global_step`` remarks: same assumptions as
        # for the training summary (tensorboard SummaryWriter API).
        if images:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, dataformats="HWC")
        if audios:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def _build_criterion(self):
        """Return the generator and discriminator loss modules as a dict."""
        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

    def get_state_dict(self):
        """Collect model, optimizer, scheduler and progress state in one dict."""
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def _build_optimizer(self):
        """Create one AdamW optimizer per network with shared hyperparameters."""
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
        return optimizer

    def _build_scheduler(self):
        """Create one ExponentialLR scheduler per optimizer."""
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler

    def _build_model(self):
        """Instantiate generator and discriminator and store them on ``self.model``."""
        net_g = Jets(self.cfg)
        net_d = MultiScaleMultiPeriodDiscriminator()
        self.model = {"generator": net_g, "discriminator": net_d}
        return self.model

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)`` - the averaged total
            loss and a dict of averaged per-key losses.
        """
        self.model["generator"].train()
        self.model["discriminator"].train()

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            with self.accelerator.accumulate(self.model):
                # Skip batches containing an item shorter than one training
                # segment; no segment of the configured size can be cut.
                if batch["target_len"].min() < self.cfg.train.segment_size:
                    continue
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            # Only account for / log losses once per accumulation window.
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key in epoch_losses:
                        epoch_losses[key] += value
                    else:
                        epoch_losses[key] = value

                    self.accelerator.log(
                        {
                            "Step/Train {} Loss".format(key): value,
                        },
                        step=self.step,
                    )

                self.step += 1

        self.accelerator.wait_for_everyone()

        # NOTE(review): the averages divide by the full dataloader length,
        # although skipped (too short) batches contribute nothing - confirm
        # this is intended.
        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )

        return epoch_sum_loss, epoch_losses

    def _train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch (dict): Collated batch; ``batch["audio"]`` is read here and
                the whole dict is passed to the generator.

        Returns:
            tuple: ``(total_loss, train_losses, training_stats)`` where
            ``total_loss`` is the float sum of generator and discriminator
            losses, ``train_losses`` maps loss names to floats, and
            ``training_stats`` is an empty dict.
        """
        train_losses = {}
        total_loss = 0
        training_stats = {}

        # Train Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)
        speech_hat_, _, _, start_idxs, *_ = outputs_g

        # Discriminator output
        # Cut the reference audio at the positions the generator reports;
        # start and length are both scaled by ``upsample_factor``
        # (presumably frame units -> waveform samples; confirm against Jets).
        speech = batch["audio"].unsqueeze(1)
        upsample_factor = self.cfg.train.upsample_factor
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * upsample_factor,
            segment_size=self.cfg.train.segment_size * upsample_factor,
        )
        # Detach so the discriminator loss does not backprop into the generator.
        p_hat = self.model["discriminator"](speech_hat_.detach())
        p = self.model["discriminator"](speech_)

        # Discriminator loss
        loss_d = self.criterion["discriminator"](p, p_hat)
        train_losses.update(loss_d)

        # BP and Grad Updated
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        # Train Generator
        p_hat = self.model["discriminator"](speech_hat_)
        with torch.no_grad():
            p = self.model["discriminator"](speech_)

        outputs_d = (p_hat, p)
        loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
        train_losses.update(loss_g)

        # BP and Grad Updated
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["g_total_loss"])
        self.optimizer["optimizer_g"].step()

        # Convert the logged copies to Python floats; ``loss_g`` / ``loss_d``
        # keep their tensors, so the sum below still works.
        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    @torch.inference_mode()
    def _valid_step(self, batch):
        """Compute discriminator and generator losses for one batch, no updates.

        Args:
            batch (dict): Collated batch; ``batch["audio"]`` is read here and
                the whole dict is passed to the generator.

        Returns:
            tuple: ``(total_loss, valid_losses, valid_stats)`` where
            ``total_loss`` is the float sum of generator and discriminator
            losses, ``valid_losses`` maps loss names to floats, and
            ``valid_stats`` is an empty dict.
        """
        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        # Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)
        speech_hat_, _, _, start_idxs, *_ = outputs_g

        # Discriminator output
        speech = batch["audio"].unsqueeze(1)
        upsample_factor = self.cfg.train.upsample_factor
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * upsample_factor,
            segment_size=self.cfg.train.segment_size * upsample_factor,
        )
        p_hat = self.model["discriminator"](speech_hat_.detach())
        p = self.model["discriminator"](speech_)

        # Discriminator loss
        loss_d = self.criterion["discriminator"](p, p_hat)
        valid_losses.update(loss_d)

        # Generator loss
        p_hat = self.model["discriminator"](speech_hat_)
        with torch.no_grad():
            p = self.model["discriminator"](speech_)

        outputs_d = (p_hat, p)
        loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
        valid_losses.update(loss_g)

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)`` - the total loss and the
            per-key losses, each divided by ``len(self.valid_dataloader)``.
        """
        # Put every network into eval mode, whether the model is a dict of
        # networks or a single module.
        if isinstance(self.model, dict):
            for net in self.model.values():
                net.eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    if key in epoch_losses:
                        epoch_losses[key] += value
                    else:
                        epoch_losses[key] = value

                    self.accelerator.log(
                        {
                            "Step/Valid {} Loss".format(key): value,
                        },
                        step=self.step,
                    )

        epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
        for key in epoch_losses.keys():
            epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses