import torch


def feature_loss(fmap_r, fmap_g):
    """
    Compute the feature-matching loss between reference and generated feature maps.

    Every pair of corresponding feature maps contributes its mean absolute
    difference; the contributions are summed and the total is doubled.

    Args:
        fmap_r (list of list of torch.Tensor): Reference feature maps. The outer
            sequence is iterated in lockstep with ``fmap_g`` and each element is
            itself iterated (presumably one inner list per discriminator --
            confirm against the caller).
        fmap_g (list of list of torch.Tensor): Generated feature maps, with the
            same nesting as ``fmap_r``.

    Returns:
        torch.Tensor: Scalar float32 loss. For empty input the Python int ``0``
        is returned, because the accumulator is never touched.
    """
    loss = 0
    for maps_r, maps_g in zip(fmap_r, fmap_g):
        for ref, gen in zip(maps_r, maps_g):
            # The reference branch is detached so gradients only reach the
            # generated feature maps.
            ref = ref.float().detach()
            gen = gen.float()
            loss += torch.mean(torch.abs(ref - gen))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Compute the least-squares discriminator loss for real and generated outputs.

    Real outputs are pushed towards 1 and generated outputs towards 0.

    Args:
        disc_real_outputs (list of torch.Tensor): Discriminator outputs for
            real samples.
        disc_generated_outputs (list of torch.Tensor): Discriminator outputs
            for generated samples.

    Returns:
        tuple: ``(loss, r_losses, g_losses)`` where ``loss`` is the summed
        scalar tensor (Python int ``0`` for empty input) and ``r_losses`` /
        ``g_losses`` are lists of Python floats, one per output pair.
    """
    loss = 0
    r_losses = []
    g_losses = []
    for real_out, gen_out in zip(disc_real_outputs, disc_generated_outputs):
        real_out = real_out.float()
        gen_out = gen_out.float()
        r_loss = torch.mean((1 - real_out) ** 2)
        g_loss = torch.mean(gen_out**2)
        loss += r_loss + g_loss
        # Per-pair values are stored as plain floats (no graph attached).
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    """
    Compute the least-squares generator loss from discriminator outputs.

    Each discriminator output for a generated sample is pushed towards 1.

    Args:
        disc_outputs (list of torch.Tensor): Discriminator outputs for
            generated samples.

    Returns:
        tuple: ``(loss, gen_losses)`` where ``loss`` is the summed scalar
        tensor (Python int ``0`` for empty input) and ``gen_losses`` is a list
        of the per-output scalar tensors. Unlike ``discriminator_loss``, the
        list entries are tensors, not floats.
    """
    loss = 0
    gen_losses = []
    for gen_out in disc_outputs:
        gen_out = gen_out.float()
        term = torch.mean((1 - gen_out) ** 2)
        gen_losses.append(term)
        loss += term

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    Compute the masked Kullback-Leibler divergence loss.

    Per element the loss is
    ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * exp(-2 * logs_p)``.
    It is multiplied by ``z_mask``, summed, and divided by ``sum(z_mask)``.

    Args:
        z_p (torch.Tensor): Latent variable z_p [b, h, t_t].
        logs_q (torch.Tensor): Log standard deviation of q [b, h, t_t].
        m_p (torch.Tensor): Mean of p [b, h, t_t].
        logs_p (torch.Tensor): Log standard deviation of p [b, h, t_t]
            (``exp(-2 * logs_p)`` is used as the inverse variance).
        z_mask (torch.Tensor): Mask for the latent variables; it must
            broadcast against the other inputs.

    Returns:
        torch.Tensor: Scalar float32 loss. The result is NaN when ``z_mask``
        sums to zero, because the normaliser is then zero.
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    # Normalise by the mask total so masked-out positions do not dilute the loss.
    masked_total = torch.sum(kl * z_mask)
    return masked_total / torch.sum(z_mask)