antoniomae1234's picture
changes in flenema
2493d72 verified
import math
import torch
from torch import nn
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need". The encoding table
    is precomputed once and stored as the ``pe`` buffer with shape
    ``[1, channels, max_len]``.

    Args:
        channels (int): embedding size. Must be even.
        dropout (float): dropout parameter. A dropout layer is only created
            when this is greater than zero.
        max_len (int): number of positions precomputed in the table.

    Raises:
        ValueError: if ``channels`` is odd.
    """

    def __init__(self, channels, dropout=0.0, max_len=5000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                "odd channels (got channels={:d})".format(channels))
        # One column of time steps times one row of inverse frequencies
        # gives the [max_len, channels // 2] grid of angles.
        timesteps = torch.arange(0, max_len).unsqueeze(1).float()
        inv_freq = torch.exp(
            torch.arange(0, channels, 2, dtype=torch.float)
            * -(math.log(10000.0) / channels))
        angles = timesteps * inv_freq
        table = torch.zeros(max_len, channels)
        table[:, 0::2] = torch.sin(angles)  # even channels carry the sine
        table[:, 1::2] = torch.cos(angles)  # odd channels carry the cosine
        # [max_len, channels] -> [1, channels, max_len] to match [B, C, T].
        self.register_buffer('pe', table.unsqueeze(0).transpose(1, 2))
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        self.channels = channels

    def forward(self, x, mask=None, first_idx=None, last_idx=None):
        """Scale ``x`` by ``sqrt(channels)`` and add the positional encoding.

        When ``first_idx`` is given, the table slice ``first_idx:last_idx``
        is added as is and ``mask`` is not applied. Otherwise the first ``T``
        positions are added, multiplied by ``mask`` when one is provided.

        Shapes:
            x: [B, C, T]
            mask: [B, 1, T]
            first_idx: int
            last_idx: int

        Raises:
            RuntimeError: if no ``first_idx`` is given and ``T`` exceeds the
                number of precomputed positions.
        """
        x = x * math.sqrt(self.channels)
        if first_idx is not None:
            # Windowed mode: caller picks the slice of positions explicitly.
            x = x + self.pe[:, :, first_idx:last_idx]
        else:
            seq_len = x.size(2)
            if self.pe.size(2) < seq_len:
                raise RuntimeError(
                    f"Sequence is {x.size(2)} but PositionalEncoding is"
                    f" limited to {self.pe.size(2)}. See max_len argument.")
            pos_enc = self.pe[:, :, :seq_len]
            if mask is not None:
                pos_enc = pos_enc * mask
            x = x + pos_enc
        # The dropout attribute only exists when dropout > 0 was requested.
        if hasattr(self, 'dropout'):
            x = self.dropout(x)
        return x
class RelativePositionTransformerEncoder(nn.Module):
    """Speedy speech encoder built on Transformer with Relative Position encoding.

    A fixed residual convolutional prenet (kernel size 5, three residual
    blocks of one convolution each, dilation 1) feeds a
    ``RelativePositionTransformer``.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels
        params (dict): keyword arguments forwarded to
            ``RelativePositionTransformer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # The prenet configuration is hard-coded; only the transformer is
        # configurable through ``params``.
        prenet_kwargs = {
            "kernel_size": 5,
            "num_res_blocks": 3,
            "num_conv_blocks": 1,
            "dilations": [1, 1, 1],
        }
        self.prenet = ResidualConv1dBNBlock(
            in_channels, hidden_channels, hidden_channels, **prenet_kwargs)
        self.rel_pos_transformer = RelativePositionTransformer(
            hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the prenet and the transformer over ``x``.

        ``x_mask`` defaults to the scalar 1 when omitted; that value is also
        what gets passed on to the transformer. ``g`` is accepted but unused.
        """
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        return self.rel_pos_transformer(hidden, mask)
class ResidualConv1dBNEncoder(nn.Module):
    """Residual Convolutional Encoder as in the original Speedy Speech paper

    Pipeline: 1x1 convolution + ReLU prenet, a ``ResidualConv1dBNBlock``,
    then a postnet (1x1 conv, ReLU, batch norm, 1x1 conv) applied to the sum
    of the block output and the raw input.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        input_projection = nn.Conv1d(in_channels, hidden_channels, 1)
        self.prenet = nn.Sequential(input_projection, nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(
            hidden_channels, hidden_channels, hidden_channels, **params)
        self.postnet = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Encode ``x``; ``x_mask`` defaults to the scalar 1 and ``g`` is unused."""
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        hidden = self.res_conv_block(hidden, mask)
        # Skip connection from the raw input around prenet and residual block.
        projected = self.postnet(hidden + x) * mask
        # The mask is applied once more on the way out, as before.
        return projected * mask
class Encoder(nn.Module):
    """Factory class for Speedy Speech encoder enables different encoder types internally.

    Args:
        in_hidden_channels (int): input and hidden channels. Model keeps the
            input channels for the intermediate layers.
        out_channels (int): number of output channels.
        encoder_type (str): encoder layer types. 'transformer' or
            'residual_conv_bn' (matched case-insensitively).
            Default 'residual_conv_bn'.
        encoder_params (dict): model parameters for specified encoder type.
            When ``None``, the defaults listed below for the selected
            ``encoder_type`` are used.
        c_in_channels (int): number of channels for conditional input.
            Stored on the instance; not otherwise used by this class.

    Raises:
        NotImplementedError: if ``encoder_type`` is not a known type.

    Note:
        Default encoder_params...

        for 'transformer'
            encoder_params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None
            },

        for 'residual_conv_bn'
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }
    """

    def __init__(
            self,
            in_hidden_channels,
            out_channels,
            encoder_type='residual_conv_bn',
            encoder_params=None,
            c_in_channels=0):
        super().__init__()
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # Resolve the encoder class first so an unknown type fails before any
        # layers are built.
        kind = encoder_type.lower()
        if kind == "transformer":
            encoder_cls = RelativePositionTransformerEncoder
        elif kind == 'residual_conv_bn':
            encoder_cls = ResidualConv1dBNEncoder
        else:
            raise NotImplementedError(' [!] unknown encoder type.')

        # A fresh dict is built per instance instead of sharing one mutable
        # default across every call of the constructor.
        if encoder_params is None:
            encoder_params = self._default_encoder_params(kind)

        self.encoder = encoder_cls(in_hidden_channels,
                                   out_channels,
                                   in_hidden_channels,
                                   encoder_params)

    @staticmethod
    def _default_encoder_params(encoder_type):
        """Return a new dict of default parameters for a lower-cased encoder type."""
        if encoder_type == "transformer":
            return {
                "hidden_channels_ffn": 128,
                "num_heads": 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None,
            }
        return {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13,
        }

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """Encode ``x`` with the selected encoder and mask the result.

        ``g`` is accepted for interface compatibility but is not used.

        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask