Spaces:
Runtime error
Runtime error
File size: 2,431 Bytes
4efe6b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
def feature_loss(fmap_r, fmap_g):
    """
    Feature-matching loss: doubled sum of mean absolute differences.

    Reference and generated feature maps are paired up level by level
    (outer lists first, then the tensors inside each). Pairing uses
    ``zip``, so any surplus entries on the longer side are ignored.

    Args:
        fmap_r (list of list of torch.Tensor): Reference feature maps.
            They are detached, so no gradient reaches them.
        fmap_g (list of list of torch.Tensor): Generated feature maps.

    Returns:
        torch.Tensor: Scalar loss. When there is nothing to pair, the
        plain integer ``0`` is returned instead of a tensor.
    """
    # Both sides are promoted to float32 before the subtraction; only the
    # generated side stays attached to the autograd graph.
    per_layer_l1 = (
        torch.mean(torch.abs(ref_map.float().detach() - gen_map.float()))
        for ref_group, gen_group in zip(fmap_r, fmap_g)
        for ref_map, gen_map in zip(ref_group, gen_group)
    )
    return sum(per_layer_l1) * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Least-squares discriminator loss over paired discriminator outputs.

    For every (real, generated) pair the real output is pushed towards 1
    via ``mean((1 - real) ** 2)`` and the generated output towards 0 via
    ``mean(generated ** 2)``. Pairing uses ``zip``, so surplus entries on
    the longer list are ignored.

    Args:
        disc_real_outputs (list of torch.Tensor): Discriminator outputs
            for real samples.
        disc_generated_outputs (list of torch.Tensor): Discriminator
            outputs for generated samples.

    Returns:
        tuple: ``(loss, r_losses, g_losses)`` where ``loss`` is the summed
        scalar tensor (integer ``0`` if there are no pairs) and the two
        lists hold the per-pair real / generated terms as Python floats.
    """
    total = 0
    real_terms = []
    fake_terms = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        # Cast to float32 before squaring.
        real_term = torch.mean((1 - real_out.float()) ** 2)
        fake_term = torch.mean(fake_out.float() ** 2)
        total = total + (real_term + fake_term)
        # .item() stores plain numbers, detached from the graph.
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
def generator_loss(disc_outputs):
    """
    Least-squares generator loss from discriminator outputs.

    Each discriminator output contributes ``mean((1 - output) ** 2)``,
    i.e. the generator is rewarded when the output is close to 1.

    Args:
        disc_outputs (list of torch.Tensor): Discriminator outputs for
            generated samples.

    Returns:
        tuple: ``(loss, gen_losses)`` where ``loss`` is the sum of all
        per-output terms (integer ``0`` for an empty list) and
        ``gen_losses`` lists those terms as scalar tensors.
    """
    # Per-output terms are kept as tensors (not converted to floats).
    gen_losses = [
        torch.mean((1 - fake_out.float()) ** 2) for fake_out in disc_outputs
    ]
    # sum() starts from 0 and adds left to right, producing a new tensor,
    # so the entries of gen_losses are never modified.
    return sum(gen_losses), gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    Masked, averaged KL-style loss between a posterior sample and a prior.

    Per element the loss is
    ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * exp(-2 * logs_p)``.
    The result is multiplied by ``z_mask``, summed, and divided by the sum
    of ``z_mask``.

    Args:
        z_p (torch.Tensor): Latent variable z_p [b, h, t_t].
        logs_q (torch.Tensor): Log scale of q [b, h, t_t].
        m_p (torch.Tensor): Mean of p [b, h, t_t].
        logs_p (torch.Tensor): Log scale of p [b, h, t_t]. The squared
            error is weighted by ``exp(-2 * logs_p)``, which matches a log
            standard deviation rather than a log variance.
        z_mask (torch.Tensor): Mask for the latent variables; must be
            multipliable with the [b, h, t_t] loss tensor.

    Returns:
        torch.Tensor: Scalar loss normalised by ``sum(z_mask)``. An
        all-zero mask therefore yields a division by zero.
    """
    # Work in float32 regardless of the incoming dtypes.
    sample = z_p.float()
    log_scale_q = logs_q.float()
    prior_mean = m_p.float()
    log_scale_p = logs_p.float()
    mask = z_mask.float()

    log_ratio_term = log_scale_p - log_scale_q - 0.5
    squared_error_term = (
        0.5 * ((sample - prior_mean) ** 2) * torch.exp(-2.0 * log_scale_p)
    )
    elementwise = log_ratio_term + squared_error_term

    masked_total = torch.sum(elementwise * mask)
    return masked_total / torch.sum(mask)
|