Spaces:
Running
Running
File size: 4,629 Bytes
cea6632 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021
import torch
from Layers.LayerNorm import LayerNorm
class DurationPredictor(torch.nn.Module):
    """
    Duration predictor module.

    This is a module of duration predictor described
    in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The duration predictor predicts a duration of each frame in log domain
    from the hidden embeddings of encoder.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Note:
        The calculation domain of outputs is different
        between in `forward` and in `inference`. In `forward`,
        the outputs are calculated in log domain but in `inference`,
        those are calculated in linear domain.
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
        """
        Initialize duration predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.
        """
        super().__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        for idx in range(n_layers):
            # Only the first layer consumes the encoder dimension; later
            # layers operate on the convolution channels.
            in_chans = idim if idx == 0 else n_chans
            self.conv.append(
                torch.nn.Sequential(
                    # stride 1 with (kernel_size - 1) // 2 padding keeps the
                    # time axis length unchanged for odd kernel sizes.
                    torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
                    torch.nn.ReLU(),
                    # Project-local LayerNorm; presumably normalizes over the
                    # channel axis (dim=1) -- confirm in Layers.LayerNorm.
                    LayerNorm(n_chans, dim=1),
                    torch.nn.Dropout(dropout_rate),
                )
            )
        # Projects each frame's channel vector to a single scalar.
        self.linear = torch.nn.Linear(n_chans, 1)

    def _forward(self, xs, x_masks=None, is_inference=False):
        """
        Shared implementation of `forward` and `inference`.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).
            is_inference (bool, optional): If True, convert the log-domain
                prediction into non-negative integer durations.

        Returns:
            Tensor: Log-domain predictions (B, Tmax), or a LongTensor of
                linear-domain durations (B, Tmax) when `is_inference` is True.
                Masked positions are set to zero in both cases.
        """
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)

        # NOTE: calculate in log domain
        xs = self.linear(xs.transpose(1, -1)).squeeze(-1)  # (B, Tmax)

        if is_inference:
            # NOTE: calculate in linear domain; clamp avoids negative values
            xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()

        if x_masks is not None:
            # masked_fill accepts only boolean masks in current PyTorch, so
            # byte (uint8) masks are cast first; bool masks pass through.
            xs = xs.masked_fill(x_masks.bool(), 0.0)
        return xs

    def forward(self, xs, x_masks=None):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).
        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """
        Inference duration.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            LongTensor: Batch of predicted durations in linear domain (B, Tmax).
        """
        return self._forward(xs, x_masks, True)
class DurationPredictorLoss(torch.nn.Module):
    """
    Loss function module for duration predictor.

    The ground-truth durations are moved into the log domain before the
    error is measured, so that the regression target is closer to Gaussian.
    """

    def __init__(self, offset=1.0, reduction="mean"):
        """
        Initialize duration predictor loss module.

        Args:
            offset (float, optional): Offset value to avoid nan in log domain.
            reduction (str): Reduction type in loss calculation.
        """
        super().__init__()
        self.offset = offset
        self.criterion = torch.nn.MSELoss(reduction=reduction)

    def forward(self, outputs, targets):
        """
        Calculate forward propagation.

        Args:
            outputs (Tensor): Batch of prediction durations in log domain (B, T)
            targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)

        Returns:
            Tensor: Squared-error loss value, reduced according to the
                `reduction` chosen at construction time.

        Note:
            `outputs` is in log domain but `targets` is in linear domain.
        """
        # Bring the linear-domain targets into the same (log) domain as the
        # predictions; the offset keeps zero-length durations finite.
        log_targets = (targets.float() + self.offset).log()
        return self.criterion(outputs, log_targets)
|