import torch
import torch.nn.functional as F
from module import *
import numpy as np
import math

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = max(length)
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)
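
# Hedged mini-example (not part of the original module): sequence_mask turns a
# batch of lengths into a boolean [B, T] padding mask, e.g.
#   sequence_mask(torch.tensor([3, 1]), 4)
#   -> tensor([[ True,  True,  True, False],
#              [ True, False, False, False]])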

def maximum_path(value, mask, max_neg_val=-np.inf):
    """Find the best monotonic alignment path through `value`.
    Numpy-friendly version; about 4 times faster than the torch version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)  # np.bool is removed in NumPy >= 1.24

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
    for j in range(t_y):
        # v0: best score coming from the previous text position; v1: staying put.
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = (v1 >= v0)
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask
        # Text positions beyond frame j are unreachable on a monotonic path.
        index_mask = (x_range <= j)
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    # Backtrack from the last valid text index to recover the best path.
    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path
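
# Hedged usage sketch (illustrative, not the project's training code): how
# maximum_path is typically called in Glow-TTS-style training. The shapes,
# lengths, and the random `value` tensor below are assumptions for demonstration.
def _demo_maximum_path():
    b, t_x, t_y = 2, 5, 8
    value = torch.randn(b, t_x, t_y)  # e.g. log-likelihood of each (token, frame) pair
    x_len = torch.tensor([5, 3])
    y_len = torch.tensor([8, 6])
    mask = (sequence_mask(x_len, t_x).unsqueeze(-1) * sequence_mask(y_len, t_y).unsqueeze(1)).float()
    path = maximum_path(value, mask)
    assert path.shape == (b, t_x, t_y)  # one-hot along t_x for every valid frame
    return path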

def generate_path(duration, mask):
    """Expand per-token durations into a [b, t_x, t_y] alignment path.
    duration: [b, t_x]
    mask: [b, t_x, t_y]
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Differencing along the token axis turns cumulative masks into per-token spans.
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path * mask
    return path
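
# Hedged usage sketch (illustrative, not from the original training loop): expand
# per-token durations into a hard alignment matrix. The duration values and
# lengths below are made up for demonstration.
def _demo_generate_path():
    duration = torch.tensor([[2, 1, 3],
                             [1, 2, 0]])  # [b=2, t_x=3], frames assigned to each token
    x_len = torch.tensor([3, 2])
    y_len = duration.sum(1)               # tensor([6, 3])
    t_y = int(y_len.max())
    mask = (sequence_mask(x_len, 3).unsqueeze(-1) * sequence_mask(y_len, t_y).unsqueeze(1)).float()
    path = generate_path(duration, mask)  # [2, 3, 6]; row i is 1 on the frames of token i
    return path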

def mle_loss(z, m, logs, logdet, mask):
    # negative normal log-likelihood without the constant term
    l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m) ** 2))
    l = l - torch.sum(logdet)  # log-determinant of the flow's Jacobian
    # average across batch, channel, and time axes
    l = l / torch.sum(torch.ones_like(z) * mask)
    l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
    return l

def duration_loss(logw, logw_, lengths):
    # MSE between predicted and target log-durations, normalized by total length
    l = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return l
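
# Hedged usage sketch: random stand-ins with the shapes these losses expect in a
# Glow-TTS-style model ([B, C, T] latents, [B, t_x] log-durations). All values
# below are illustrative assumptions, not real model outputs.
def _demo_glow_losses():
    b, c, t = 2, 80, 100
    z = torch.randn(b, c, t)     # latents from the flow
    m = torch.zeros(b, c, t)     # predicted prior mean
    logs = torch.zeros(b, c, t)  # predicted prior log-std
    logdet = torch.zeros(b)      # per-sample log|det J| from the flow
    mask = torch.ones(b, 1, t)   # frame-level mask, broadcast over channels
    l_mle = mle_loss(z, m, logs, logdet, mask)

    logw = torch.randn(b, 20)    # predicted log-durations
    logw_ = torch.randn(b, 20)   # target log-durations
    lengths = torch.tensor([20, 15])
    l_dur = duration_loss(logw, logw_, lengths)
    return l_mle, l_dur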

class AttrDict(dict):
    """Dictionary whose keys are also accessible as attributes."""
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def GAN_Loss_Generator(gen_outputs):
    """Least-squares adversarial loss for the generator.
    gen_outputs: list of (B, ?) tensors  # outputs of MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for DG in gen_outputs:
        loss += torch.mean((DG - 1) ** 2)
    return loss

def GAN_Loss_Discriminator(real_outputs, gen_outputs):
    """Least-squares adversarial loss for the discriminator.
    real_outputs: list of (B, ?) tensors  # outputs of MPD (len=5) or MSD (len=3)
    gen_outputs: list of (B, ?) tensors   # outputs of MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for D, DG in zip(real_outputs, gen_outputs):
        loss += torch.mean((D - 1) ** 2 + DG ** 2)
    return loss

def Mel_Spectrogram_Loss(real_mel, gen_mel):
    """
    real_mel: (B, F, 80)  # mel-spectrogram taken from the dataloader
    gen_mel: (B, F, 80)   # mel-spectrogram of the waveform produced by the generator
    """
    loss = F.l1_loss(real_mel, gen_mel)
    return 45 * loss  # lambda_mel = 45, the mel-loss weight used in HiFi-GAN

def Feature_Matching_Loss(real_features, gen_features):
    """L1 distance between intermediate discriminator feature maps.
    real_features: list of lists of tensors  # outputs of MPD (len=[5, 6]) or MSD (len=[3, 7])
    gen_features: list of lists of tensors   # outputs of MPD (len=[5, 6]) or MSD (len=[3, 7])
    """
    loss = 0
    for Ds, DGs in zip(real_features, gen_features):
        for D, DG in zip(Ds, DGs):
            loss += torch.mean(torch.abs(D - DG))
    return 2 * loss  # lambda_fm = 2, the feature-matching weight used in HiFi-GAN

def Final_Loss_Generator(mpd_gen_outputs, mpd_real_features, mpd_gen_features,
                         msd_gen_outputs, msd_real_features, msd_gen_features,
                         real_mel, gen_mel):
    """
    =====inputs=====
    [:3]: the last 3 of the MPD outputs
    [3:6]: the last 3 of the MSD outputs
    [6:8]: real_mel and gen_mel
    =====outputs=====
    Generator_Loss, Mel_Loss, Adv, FM
    """
    Gen_Adv1 = GAN_Loss_Generator(mpd_gen_outputs)
    Gen_Adv2 = GAN_Loss_Generator(msd_gen_outputs)
    Adv = Gen_Adv1 + Gen_Adv2

    FM1 = Feature_Matching_Loss(mpd_real_features, mpd_gen_features)
    FM2 = Feature_Matching_Loss(msd_real_features, msd_gen_features)
    FM = FM1 + FM2

    Mel_Loss = Mel_Spectrogram_Loss(real_mel, gen_mel)

    Generator_Loss = Adv + FM + Mel_Loss
    return Generator_Loss, Mel_Loss, Adv, FM

def Final_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs,
                             msd_real_outputs, msd_gen_outputs):
    """
    =====inputs=====
    [:2]: the first 2 of the MPD outputs
    [2:4]: the first 2 of the MSD outputs
    =====outputs=====
    Discriminator_Loss
    """
    Disc_Adv1 = GAN_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs)
    Disc_Adv2 = GAN_Loss_Discriminator(msd_real_outputs, msd_gen_outputs)
    Discriminator_Loss = Disc_Adv1 + Disc_Adv2
    return Discriminator_Loss
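
# Hedged end-to-end sketch of the HiFi-GAN-style loss wiring above. The list
# lengths mimic MPD (5 sub-discriminators) and MSD (3 sub-discriminators); all
# tensors are random placeholders, not real discriminator outputs.
def _demo_gan_losses():
    mpd_real = [torch.randn(2, 10) for _ in range(5)]
    mpd_gen = [torch.randn(2, 10) for _ in range(5)]
    msd_real = [torch.randn(2, 10) for _ in range(3)]
    msd_gen = [torch.randn(2, 10) for _ in range(3)]
    # each sub-discriminator also exposes a list of intermediate feature maps
    mpd_real_feat = [[torch.randn(2, 4, 8) for _ in range(6)] for _ in range(5)]
    mpd_gen_feat = [[torch.randn(2, 4, 8) for _ in range(6)] for _ in range(5)]
    msd_real_feat = [[torch.randn(2, 4, 8) for _ in range(7)] for _ in range(3)]
    msd_gen_feat = [[torch.randn(2, 4, 8) for _ in range(7)] for _ in range(3)]
    real_mel = torch.randn(2, 50, 80)
    gen_mel = torch.randn(2, 50, 80)

    d_loss = Final_Loss_Discriminator(mpd_real, mpd_gen, msd_real, msd_gen)
    g_loss, mel_loss, adv, fm = Final_Loss_Generator(mpd_gen, mpd_real_feat, mpd_gen_feat,
                                                     msd_gen, msd_real_feat, msd_gen_feat,
                                                     real_mel, gen_mel)
    return d_loss, g_loss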

class Adam:
    """Adam wrapper with an optional Noam (Transformer) learning-rate schedule."""
    def __init__(self, params, scheduler, dim_model, warmup_steps=4000, lr=1e0, betas=(0.9, 0.98), eps=1e-9):
        self.params = params
        self.scheduler = scheduler
        self.dim_model = dim_model
        self.warmup_steps = warmup_steps
        self.lr = lr
        self.betas = betas
        self.eps = eps

        self.step_num = 1
        self.cur_lr = lr * self._get_lr_scale()

        self._optim = torch.optim.Adam(params, lr=self.cur_lr, betas=betas, eps=eps)
        self.param_groups = self._optim.param_groups

    def _get_lr_scale(self):
        if self.scheduler == "noam":
            # lr_scale = dim_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
            return np.power(self.dim_model, -0.5) * np.min([np.power(self.step_num, -0.5),
                                                            self.step_num * np.power(self.warmup_steps, -1.5)])
        else:
            return 1

    def _update_learning_rate(self):
        self.step_num += 1
        if self.scheduler == "noam":
            self.cur_lr = self.lr * self._get_lr_scale()
            for param_group in self._optim.param_groups:
                param_group['lr'] = self.cur_lr

    def get_lr(self):
        return self.cur_lr

    def step(self):
        self._optim.step()
        self._update_learning_rate()

    def zero_grad(self):
        self._optim.zero_grad()

    def load_state_dict(self, d):
        self._optim.load_state_dict(d)

    def state_dict(self):
        return self._optim.state_dict()
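
# Hedged usage sketch for the Noam-scheduled Adam wrapper: a throwaway linear
# model stands in for the real network, and dim_model/warmup values are
# illustrative assumptions.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 16)
    optim = Adam(model.parameters(), scheduler="noam", dim_model=192, warmup_steps=4000)
    for _ in range(3):
        optim.zero_grad()
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        optim.step()  # also advances the Noam schedule
    print("current lr:", optim.get_lr())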