""" Taken from ESPNet """

import torch
import torch.nn.functional as F

from .Attention import RelPositionMultiHeadedAttention
from .Convolution import ConvolutionModule
from .EncoderLayer import EncoderLayer
from .LayerNorm import LayerNorm
from .MultiLayeredConv1d import MultiLayeredConv1d
from .MultiSequential import repeat
from .PositionalEncoding import RelPositionalEncoding
from .Swish import Swish


class Conformer(torch.nn.Module):
    """
    Conformer encoder module.

    Args:
        idim (int): Input dimension. Not referenced by this implementation;
            kept for interface compatibility.
        attention_dim (int): Dimension of attention (model width).
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of hidden units of the position-wise
            feed forward (MultiLayeredConv1d) module.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Optional[torch.nn.Module]): Module applied to the input
            before positional encoding, or None to skip it. Any other value
            (including strings such as the default "conv2d") raises ValueError.
        normalize_before (bool): Forwarded to every EncoderLayer; when True a
            final LayerNorm is also applied after the encoder stack.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_conv_kernel_size (int): Kernel size of the positionwise
            conv1d layers.
        macaron_style (bool): Whether to give each layer a second
            (macaron-style) position-wise feed forward module.
        use_cnn_module (bool): Whether to use convolution module.
        cnn_module_kernel (int): Kernel size of convolution module.
        zero_triu (bool): Forwarded to RelPositionMultiHeadedAttention.
        utt_embed (Optional[int]): Dimension of the utterance embedding. When
            not None, layers that fuse the utterance embedding into the hidden
            states are created.
        connect_utt_emb_at_encoder_out (bool): If True the utterance embedding
            is fused after the encoder stack, otherwise before positional
            encoding.
        spk_emb_bottleneck_size (int): Size the utterance embedding is
            projected to before being concatenated with the hidden states.
        lang_embs (Optional[int]): Number of entries in the language
            embedding table; None disables language embeddings.

    """

    def __init__(self,
                 idim,
                 attention_dim=256,
                 attention_heads=4,
                 linear_units=2048,
                 num_blocks=6,
                 dropout_rate=0.1,
                 positional_dropout_rate=0.1,
                 attention_dropout_rate=0.0,
                 input_layer="conv2d",
                 normalize_before=True,
                 concat_after=False,
                 positionwise_conv_kernel_size=1,
                 macaron_style=False,
                 use_cnn_module=False,
                 cnn_module_kernel=31,
                 zero_triu=False,
                 utt_embed=None,
                 connect_utt_emb_at_encoder_out=True,
                 spk_emb_bottleneck_size=128,
                 lang_embs=None):
        super().__init__()

        activation = Swish()
        self.conv_subsampling_factor = 1

        # NOTE: submodule names and the Sequential wrapping below determine
        # state_dict keys; keep them unchanged so existing checkpoints load.
        if isinstance(input_layer, torch.nn.Module):
            self.embed = input_layer
            self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
        elif input_layer is None:
            self.embed = None
            self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
        else:
            # f-string so that non-str values produce the intended ValueError
            # instead of a TypeError from str + non-str concatenation.
            raise ValueError(f"unknown input_layer: {input_layer}")

        self.normalize_before = normalize_before

        self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out
        if utt_embed is not None:
            self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim)
            # embedding projection derived from https://arxiv.org/pdf/1705.08947.pdf
            self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size),
                                                            torch.nn.Softsign())
        if lang_embs is not None:
            self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim)

        # self-attention module definition
        encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)

        # feed-forward module definition
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        # Each call of the factory builds fresh submodules, so blocks do not
        # share weights (the single Swish instance is shared by reference).
        self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim,
                                                                     encoder_selfattn_layer(*encoder_selfattn_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                                                                     convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                                                                     dropout_rate,
                                                                     normalize_before,
                                                                     concat_after))
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self, xs, masks, utterance_embedding=None, lang_ids=None):
        """
        Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor, presumably (#batch, time) or
                (#batch, 1, time) — TODO confirm against EncoderLayer.
            utterance_embedding (Optional[torch.Tensor]): Embedding containing
                conditioning signals; requires the model to have been built
                with ``utt_embed`` set.
            lang_ids (Optional[torch.Tensor]): Indices into the language
                embedding table; requires the model to have been built with
                ``lang_embs`` set.

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor, as returned by the encoder stack.

        """
        if self.embed is not None:
            xs = self.embed(xs)

        if lang_ids is not None:
            lang_embs = self.language_embedding(lang_ids)
            xs = xs + lang_embs  # offset the phoneme distribution of a language

        if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out:
            xs = self._integrate_with_utt_embed(xs, utterance_embedding)

        xs = self.pos_enc(xs)

        xs, masks = self.encoders(xs, masks)
        # With relative positional encoding the stack may carry a
        # (hidden, pos_emb) tuple; only the hidden states are returned.
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out:
            xs = self._integrate_with_utt_embed(xs, utterance_embedding)

        return xs, masks

    def _integrate_with_utt_embed(self, hs, utt_embeddings):
        """
        Fuse an utterance embedding into every time step of ``hs``.

        The embedding is projected to the bottleneck size, L2-normalized,
        broadcast along the time axis, concatenated with ``hs`` and projected
        back to ``attention_dim``.
        """
        # project embedding into smaller space
        speaker_embeddings_projected = self.embedding_projection(utt_embeddings)
        # concat hidden states with spk embeds and then apply projection
        speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1)
        hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1))
        return hs