# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_detach.models.data2vec import utils
from funasr_detach.models.data2vec.multihead_attention import MultiheadAttention


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        in_d: int = 1,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        utils.TransposeLast(),
                        utils.Fp32LayerNorm(dim, elementwise_affine=True),
                        utils.TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    utils.Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        else:
            x = x.transpose(1, 2)

        for conv in self.conv_layers:
            x = conv(x)

        return x


def make_conv_pos(e, k, g):
    pos_conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
    pos_conv = nn.Sequential(pos_conv, utils.SamePad(k), nn.GELU())

    return pos_conv


class TransformerEncoder(nn.Module):
    def build_encoder_layer(self):
        if self.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=self.encoder_ffn_embed_dim,
                num_attention_heads=self.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                activation_dropout=self.activation_dropout,
                activation_fn=self.activation_fn,
                layer_norm_first=self.layer_norm_first,
            )
        else:
            logging.error("Only transformer is supported for data2vec now")
        return layer

    def __init__(
        self,
        # position
        dropout,
        encoder_embed_dim,
        required_seq_len_multiple,
        pos_conv_depth,
        conv_pos,
        conv_pos_groups,
        # transformer layers
        layer_type,
        encoder_layers,
        encoder_ffn_embed_dim,
        encoder_attention_heads,
        attention_dropout,
        activation_dropout,
        activation_fn,
        layer_norm_first,
        encoder_layerdrop,
        max_positions,
    ):
        super().__init__()

        # position
        self.dropout = dropout
        self.embedding_dim = encoder_embed_dim
        self.required_seq_len_multiple = required_seq_len_multiple

        if pos_conv_depth > 1:
            num_layers = pos_conv_depth
            k = max(3, conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            utils.SamePad(k),
                            utils.TransposeLast(),
                            torch.nn.LayerNorm(e, elementwise_affine=False),
                            utils.TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, conv_pos_groups, num_layers
            )
        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                conv_pos,
                conv_pos_groups,
            )

        # transformer layers
        self.layer_type = layer_type
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = activation_fn
        self.layer_norm_first = layer_norm_first
        self.layerdrop = encoder_layerdrop
        self.max_positions = max_positions

        self.layers = nn.ModuleList(
            [self.build_encoder_layer() for _ in range(encoder_layers)]
        )
        self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)

        self.apply(utils.init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = utils.pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = utils.pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
""" def __init__( self, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = "relu", layer_norm_first: bool = False, ) -> None: super().__init__() # Initialize parameters self.embedding_dim = embedding_dim self.dropout = dropout self.activation_dropout = activation_dropout # Initialize blocks self.activation_fn = utils.get_activation_fn(activation_fn) self.self_attn = MultiheadAttention( self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True, ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(self.activation_dropout) self.dropout3 = nn.Dropout(dropout) self.layer_norm_first = layer_norm_first # layer norm associated with the self attention layer self.self_attn_layer_norm = torch.nn.LayerNorm(self.embedding_dim) self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) # layer norm associated with the position wise feed-forward NN self.final_layer_norm = torch.nn.LayerNorm(self.embedding_dim) def forward( self, x: torch.Tensor, # (T, B, C) self_attn_mask: torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None, ): """ LayerNorm is applied either before or after the self-attention/ffn modules similar to the original Transformer imlementation. """ residual = x if self.layer_norm_first: x = self.self_attn_layer_norm(x) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False, ) x = self.dropout1(x) x = residual + x residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) layer_result = x x = self.dropout3(x) x = residual + x else: x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False, ) x = self.dropout1(x) x = residual + x x = self.self_attn_layer_norm(x) residual = x x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) layer_result = x x = self.dropout3(x) x = residual + x x = self.final_layer_norm(x) return x, (attn, layer_result)