import math

import torch
from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need".

    Args:
        channels (int): embedding size. Must be even so sin/cos pairs interleave.
        dropout (float): dropout probability applied after adding the encoding.
        max_len (int): maximum sequence length covered by the precomputed table.
    """

    def __init__(self, channels, dropout=0.0, max_len=5000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                "odd channels (got channels={:d})".format(channels))
        # Precompute the encoding table once; shape after transpose: [1, channels, max_len].
        pe = torch.zeros(max_len, channels)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
                              -(math.log(10000.0) / channels)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0).transpose(1, 2)
        # Buffer, not a parameter: follows .to()/device moves but is never trained.
        self.register_buffer('pe', pe)
        # nn.Dropout(p=0.0) is the identity and holds no parameters, so it is safe
        # to create unconditionally (replaces the old hasattr(self, 'dropout') dance).
        self.dropout = nn.Dropout(p=dropout)
        self.channels = channels

    def forward(self, x, mask=None, first_idx=None, last_idx=None):
        """Add positional encoding to ``x`` (scaled by sqrt(channels)).

        Shapes:
            x: [B, C, T]
            mask: [B, 1, T]
            first_idx: int
            last_idx: int
        """
        x = x * math.sqrt(self.channels)
        if first_idx is None:
            if self.pe.size(2) < x.size(2):
                raise RuntimeError(
                    f"Sequence is {x.size(2)} but PositionalEncoding is"
                    f" limited to {self.pe.size(2)}. See max_len argument.")
            if mask is not None:
                # Zero out encoding on padded positions.
                pos_enc = (self.pe[:, :, :x.size(2)] * mask)
            else:
                pos_enc = self.pe[:, :, :x.size(2)]
            x = x + pos_enc
        else:
            # Explicit window into the table, e.g. for chunked/streaming inference.
            x = x + self.pe[:, :, first_idx:last_idx]
        x = self.dropout(x)
        return x


class RelativePositionTransformerEncoder(nn.Module):
    """Speedy speech encoder built on Transformer with Relative Position encoding.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to RelativePositionTransformer.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # Convolutional prenet ahead of the attention stack.
        self.prenet = ResidualConv1dBNBlock(in_channels,
                                            hidden_channels,
                                            hidden_channels,
                                            kernel_size=5,
                                            num_res_blocks=3,
                                            num_conv_blocks=1,
                                            dilations=[1, 1, 1])
        self.rel_pos_transformer = RelativePositionTransformer(
            hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
        """
        if x_mask is None:
            x_mask = 1  # scalar 1 broadcasts as a no-op mask
        o = self.prenet(x) * x_mask
        o = self.rel_pos_transformer(o, x_mask)
        return o


class ResidualConv1dBNEncoder(nn.Module):
    """Residual Convolutional Encoder as in the original Speedy Speech paper.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to ResidualConv1dBNBlock.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
                                                    hidden_channels,
                                                    hidden_channels, **params)
        self.postnet = nn.Sequential(*[
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1)
        ])

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
        """
        if x_mask is None:
            x_mask = 1  # scalar 1 broadcasts as a no-op mask
        o = self.prenet(x) * x_mask
        o = self.res_conv_block(o, x_mask)
        # Residual skip from the raw input: requires in_channels == hidden_channels
        # (the Encoder factory below guarantees this by passing in_hidden_channels
        # for both).
        o = self.postnet(o + x) * x_mask
        return o * x_mask


class Encoder(nn.Module):
    """Factory class for Speedy Speech encoder enables different encoder types internally.

    Args:
        num_chars (int): number of characters.
        out_channels (int): number of output channels.
        in_hidden_channels (int): input and hidden channels. Model keeps the
            input channels for the intermediate layers.
        encoder_type (str): encoder layer types. 'transformer' or
            'residual_conv_bn'. Default 'residual_conv_bn'.
        encoder_params (dict): model parameters for specified encoder type.
        c_in_channels (int): number of channels for conditional input.

    Note:
        Default encoder_params...

        for 'transformer'
            encoder_params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None
            },

        for 'residual_conv_bn'
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }
    """

    def __init__(self,
                 in_hidden_channels,
                 out_channels,
                 encoder_type='residual_conv_bn',
                 encoder_params=None,
                 c_in_channels=0):
        super().__init__()
        # None sentinel instead of a mutable dict default (shared-state hazard);
        # the effective default is unchanged.
        if encoder_params is None:
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels
        # init encoder
        if encoder_type.lower() == "transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels,
                encoder_params)
        elif encoder_type.lower() == 'residual_conv_bn':
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
                                                   out_channels,
                                                   in_hidden_channels,
                                                   encoder_params)
        else:
            raise NotImplementedError(' [!] unknown encoder type.')

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask