marigold334 commited on
Commit
373cacd
โ€ข
1 Parent(s): 8346e07

Update Tmodel.py

Browse files
Files changed (1) hide show
  1. Tmodel.py +3 -3
Tmodel.py CHANGED
@@ -14,7 +14,7 @@ class GlowTTS(nn.Module):
14
  self.encoder = Encoder()
15
  self.decoder = Decoder()
16
 
17
- def forward(self, text, text_len, mel=None, mel_len=None, inference=False):
18
  """
19
  =====inputs=====
20
  text: (B, T)
@@ -45,7 +45,7 @@ class GlowTTS(nn.Module):
45
  if not inference: # training
46
  y_max_len = y.size(2)
47
  else: # inference
48
- dur = torch.exp(x_log_dur) * x_mask # (B, 1, T)
49
  ceil_dur = torch.ceil(dur) # (B, 1, T)
50
  y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
51
  # ceil_dur์„ [1, 2] ์ถ•์— ๋Œ€ํ•ด sumํ•œ ๋’ค ์ตœ์†Ÿ๊ฐ’์ด 1์ด์ƒ์ด ๋˜๋„๋ก ์„ค์ •. ์ •์ˆ˜ long ํƒ€์ž…์œผ๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
@@ -99,7 +99,7 @@ class GlowTTS(nn.Module):
99
  z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
100
  log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment์—์„œ ํ˜•์„ฑ๋œ duration์˜ log scale
101
 
102
- z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean)) * z_mask # z(latent representation) ์ƒ์„ฑ
103
  y, log_det = self.decoder(z, z_mask, reverse=True) # mel-spectrogram ์ƒ์„ฑ
104
  return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
105
 
 
14
  self.encoder = Encoder()
15
  self.decoder = Decoder()
16
 
17
+ def forward(self, text, text_len, mel=None, mel_len=None, inference=False, noise_scale=1., length_scale=1.):
18
  """
19
  =====inputs=====
20
  text: (B, T)
 
45
  if not inference: # training
46
  y_max_len = y.size(2)
47
  else: # inference
48
+ dur = torch.exp(x_log_dur) * x_mask * length_scale # (B, 1, T)
49
  ceil_dur = torch.ceil(dur) # (B, 1, T)
50
  y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
51
  # ceil_dur์„ [1, 2] ์ถ•์— ๋Œ€ํ•ด sumํ•œ ๋’ค ์ตœ์†Ÿ๊ฐ’์ด 1์ด์ƒ์ด ๋˜๋„๋ก ์„ค์ •. ์ •์ˆ˜ long ํƒ€์ž…์œผ๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
 
99
  z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
100
  log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment์—์„œ ํ˜•์„ฑ๋œ duration์˜ log scale
101
 
102
+ z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean) * noise_scale) * z_mask # z(latent representation) ์ƒ์„ฑ
103
  y, log_det = self.decoder(z, z_mask, reverse=True) # mel-spectrogram ์ƒ์„ฑ
104
  return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
105