mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
raw
history blame
12.5 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr_detach.models.data2vec import utils
from funasr_detach.models.data2vec.multihead_attention import MultiheadAttention
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolution blocks applied to a (batched) input signal.

    Every entry of ``conv_layers`` describes one
    ``Conv1d -> Dropout -> [norm] -> GELU`` block.

    Args:
        conv_layers: one ``(out_channels, kernel_size, stride)`` triple per
            block; the output channels of a block are the input channels of
            the next one.
        dropout: dropout probability applied right after each convolution.
        mode: ``"default"`` adds a group norm to the first block only;
            ``"layer_norm"`` adds a layer norm to every block.
        conv_bias: whether the convolutions have a bias term.
        in_d: number of input channels of the first convolution.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        in_d: int = 1,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build one ``Conv1d -> Dropout -> [norm] -> GELU`` block."""

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            # Norms are sized from this block's own output channels (n_out)
            # instead of reaching into the enclosing loop variable.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        # Conv1d emits (B, C, T); TransposeLast presumably
                        # swaps the last two axes so the norm runs over
                        # channels, then swaps them back.
                        utils.TransposeLast(),
                        utils.Fp32LayerNorm(n_out, elementwise_affine=True),
                        utils.TransposeLast(),
                    ),
                    nn.GELU(),
                )
            if is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # Assuming the GroupNorm(num_groups, num_channels) order:
                    # one group per channel.
                    utils.Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run the convolution stack.

        Args:
            x: a 2-D tensor, which gets a singleton channel axis inserted at
                dim 1, or a 3-D tensor whose axes 1 and 2 are swapped so the
                former last axis becomes the channel axis.

        Returns:
            The output of the last block in channels-first ``Conv1d`` layout
            (the reshaped input itself when ``conv_layers`` was empty).
        """
        if len(x.shape) == 2:
            # (B, T) -> (B, 1, T)
            x = x.unsqueeze(1)
        else:
            # (B, T, C) -> (B, C, T)
            x = x.transpose(1, 2)

        for conv in self.conv_layers:
            x = conv(x)

        return x
def make_conv_pos(e, k, g):
    """Build the weight-normalised convolutional positional embedding.

    Args:
        e: embedding dimension, used for both input and output channels.
        k: kernel size; the convolution pads ``k // 2`` on each side.
        g: number of convolution groups.

    Returns:
        ``nn.Sequential(weight-normed Conv1d, utils.SamePad(k), nn.GELU())``.
        SamePad presumably trims the extra frame produced by an even ``k`` —
        confirm in ``utils``.
    """
    conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)

    # std = sqrt(4 / (k * e)); the historical form multiplied 4 by
    # (1 - dropout) with dropout hard-wired to 0, i.e. by 1.0.
    nn.init.normal_(conv.weight, mean=0, std=math.sqrt(4.0 / (k * e)))
    nn.init.constant_(conv.bias, 0)

    # Conv1d weight is (out, in / groups, k); dim=2 keeps one magnitude per
    # kernel tap. NOTE: the legacy nn.utils.weight_norm stores weight_g /
    # weight_v; switching to torch.nn.utils.parametrizations.weight_norm
    # would rename those state-dict keys and break existing checkpoints.
    conv = nn.utils.weight_norm(conv, name="weight", dim=2)

    return nn.Sequential(conv, utils.SamePad(k), nn.GELU())
class TransformerEncoder(nn.Module):
    """Convolutional positional embedding followed by a transformer layer stack.

    ``forward`` / ``extract_features`` take and return batch-first tensors;
    internally ``x.transpose(0, 1)`` makes the layer stack run time-first.
    Only ``layer_type == "transformer"`` is supported.
    """

    def build_encoder_layer(self):
        """Create one encoder layer from the hyper-parameters stored on ``self``.

        Returns:
            A new :class:`TransformerSentenceEncoderLayer`.

        Raises:
            ValueError: if ``self.layer_type`` is not ``"transformer"``.
        """
        if self.layer_type != "transformer":
            # The old code only logged here and then crashed with an
            # UnboundLocalError on ``return layer``; fail with a clear message.
            raise ValueError(
                "Only transformer is supported for data2vec now, "
                f"got layer_type={self.layer_type!r}"
            )
        return TransformerSentenceEncoderLayer(
            embedding_dim=self.embedding_dim,
            ffn_embedding_dim=self.encoder_ffn_embed_dim,
            num_attention_heads=self.encoder_attention_heads,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            activation_dropout=self.activation_dropout,
            activation_fn=self.activation_fn,
            layer_norm_first=self.layer_norm_first,
        )

    def __init__(
        self,
        # position
        dropout,
        encoder_embed_dim,
        required_seq_len_multiple,
        pos_conv_depth,
        conv_pos,
        conv_pos_groups,
        # transformer layers
        layer_type,
        encoder_layers,
        encoder_ffn_embed_dim,
        encoder_attention_heads,
        attention_dropout,
        activation_dropout,
        activation_fn,
        layer_norm_first,
        encoder_layerdrop,
        max_positions,
    ):
        """Build the positional convolution, the layer stack and the final norm.

        Args:
            dropout: dropout applied after the positional conv and inside layers.
            encoder_embed_dim: model (channel) dimension.
            required_seq_len_multiple: the time axis is padded up to a multiple
                of this before the layer stack.
            pos_conv_depth: > 1 selects a stack of that many plain conv blocks;
                otherwise a single weight-normed conv is used.
            conv_pos: positional conv kernel size (split across the stack when
                ``pos_conv_depth > 1``).
            conv_pos_groups: groups of the positional convolution(s).
            layer_type: must be ``"transformer"``.
            encoder_layers: number of encoder layers.
            encoder_ffn_embed_dim: hidden size of each layer's FFN.
            encoder_attention_heads: attention heads per layer.
            attention_dropout: dropout inside self-attention.
            activation_dropout: dropout after the FFN activation.
            activation_fn: activation name passed to each layer.
            layer_norm_first: pre-norm (True) or post-norm (False) layers.
            encoder_layerdrop: LayerDrop probability used while training.
            max_positions: stored on the instance (see the note below).

        Raises:
            ValueError: if ``layer_type`` is not ``"transformer"``.
        """
        super().__init__()

        # position
        self.dropout = dropout
        self.embedding_dim = encoder_embed_dim
        self.required_seq_len_multiple = required_seq_len_multiple
        if pos_conv_depth > 1:
            num_layers = pos_conv_depth
            # Divide the conv_pos kernel width across the stacked convs,
            # never going below a kernel of 3.
            k = max(3, conv_pos // num_layers)

            def make_conv_block(e, k, g, depth):
                """Stack ``depth`` x (Conv1d -> SamePad -> LayerNorm -> GELU)."""
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            utils.SamePad(k),
                            utils.TransposeLast(),
                            torch.nn.LayerNorm(e, elementwise_affine=False),
                            utils.TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(depth)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, conv_pos_groups, num_layers
            )
        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                conv_pos,
                conv_pos_groups,
            )

        # transformer layers
        self.layer_type = layer_type
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = activation_fn
        self.layer_norm_first = layer_norm_first
        self.layerdrop = encoder_layerdrop
        # NOTE(review): this instance attribute shadows the max_positions()
        # method below, so ``self.max_positions`` is the raw value, not a
        # callable. Kept because callers may read the attribute directly.
        self.max_positions = max_positions
        self.layers = nn.ModuleList(
            [self.build_encoder_layer() for _ in range(encoder_layers)]
        )
        self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)

        self.apply(utils.init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x``; ``layer`` optionally stops early at that layer index.

        Returns:
            ``(x, layer_results)`` as produced by :meth:`extract_features`. For
            pre-norm models the final layer norm is applied only when the full
            stack ran (``layer is None``).
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        """Run the positional conv and the layer stack, collecting layer outputs.

        Args:
            x: batch-first input; presumably (B, T, C) — TODO confirm.
            padding_mask: optional boolean mask over the first two axes of
                ``x``; ``True`` positions are zeroed before the positional conv.
            tgt_layer: if given, stop right after the layer with this index and
                return its output.
            min_layer: only layers with index >= ``min_layer`` are recorded.

        Returns:
            ``(x, layer_results)``. ``x`` is batch-first; each entry of
            ``layer_results`` is a time-first ``(x, attn, layer_result)`` tuple.
            Sequence padding added here is removed from all of them.
        """
        if padding_mask is not None:
            # NOTE(review): this writes into the caller's tensor in place.
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = utils.pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if padding_mask is None:
            # Only materialise a mask when padding was actually added.
            if pad_length > 0:
                padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
                padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = utils.pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            # LayerDrop: while training, a layer is skipped with probability
            # self.layerdrop. The NumPy draw happens whenever layerdrop > 0,
            # including eval mode.
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                """Strip the padded tail of the time-first layer outputs."""
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder.

        NOTE(review): unreachable as ``instance.max_positions()`` because
        ``__init__`` stores an attribute with the same name.
        """
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Currently a no-op that returns ``state_dict`` unchanged.
        """
        return state_dict
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.

    Args:
        embedding_dim: model (channel) dimension.
        ffn_embedding_dim: hidden size of the position-wise feed-forward net.
        num_attention_heads: number of self-attention heads.
        dropout: dropout after self-attention and after the FFN output.
        attention_dropout: dropout inside the attention module.
        activation_dropout: dropout after the FFN activation.
        activation_fn: activation name resolved by ``utils.get_activation_fn``.
        layer_norm_first: pre-norm (True) or post-norm (False) ordering.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)  # after self-attention
        self.dropout2 = nn.Dropout(self.activation_dropout)  # after FFN activation
        self.dropout3 = nn.Dropout(dropout)  # after FFN output projection

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = torch.nn.LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,  # (T, B, C)
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        Args:
            x: time-first input, (T, B, C).
            self_attn_mask: optional attention mask forwarded as ``attn_mask``
                to self-attention (in both pre- and post-norm modes).
            self_attn_padding_mask: optional mask forwarded as
                ``key_padding_mask``.

        Returns:
            ``(x, (attn, layer_result))`` where ``layer_result`` is the FFN
            output before the last dropout and residual add. ``attn`` is
            whatever self-attention returns with ``need_weights=False``
            (presumably None — confirm in MultiheadAttention).
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                # Previously omitted here, so the mask was silently ignored
                # in post-norm mode; identical behavior when it is None.
                attn_mask=self_attn_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)