Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math | |
import torch | |
from torch import nn | |
from typing import Optional, Any | |
from torch import Tensor | |
import torch.nn.functional as F | |
import torchaudio | |
import torchaudio.functional as audio_F | |
import random | |
# Seed Python's process-wide `random` module generator at import time,
# presumably for reproducibility of any code using `random.*` functions.
# NOTE(review): this is an import-time side effect that also reseeds `random`
# for every other module in the program. The PhaseShuffle classes in this
# file draw from their own `random.Random(1)` instances and are therefore not
# affected by this call.
random.seed(0)
def _get_activation_fn(activ):
    """Build an activation module from its short name.

    Args:
        activ: One of ``"relu"``, ``"lrelu"`` (LeakyReLU with slope 0.2) or
            ``"swish"`` (``x * sigmoid(x)``).

    Returns:
        An ``nn.Module`` implementing the activation. A module (rather than a
        bare lambda) is returned for every option so the result can be placed
        inside ``nn.Sequential``.

    Raises:
        RuntimeError: If ``activ`` is not one of the recognised names.
    """
    if activ == "relu":
        return nn.ReLU()
    if activ == "lrelu":
        return nn.LeakyReLU(0.2)
    if activ == "swish":
        # nn.SiLU computes x * sigmoid(x) (swish with beta = 1) and, unlike a
        # lambda, is accepted by nn.Sequential.
        return nn.SiLU()
    raise RuntimeError(
        "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
    )
class LinearNorm(torch.nn.Module):
    """Fully connected layer with Xavier-uniform weight initialisation.

    ``w_init_gain`` is any nonlinearity name accepted by
    ``torch.nn.init.calculate_gain`` (e.g. ``"linear"``, ``"tanh"``); the
    resulting gain scales the Xavier-uniform bound of the weight matrix.
    The bias (if any) keeps ``nn.Linear``'s default initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        layer = nn.Linear(in_dim, out_dim, bias=bias)
        gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the affine map to the last dimension of ``x``."""
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform weight initialisation.

    If ``padding`` is ``None`` the kernel size must be odd, and the padding is
    set to ``dilation * (kernel_size - 1) / 2`` so that with ``stride=1`` the
    output length equals the input length.

    ``w_init_gain`` and ``param`` are forwarded to
    ``torch.nn.init.calculate_gain`` to obtain the Xavier gain.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super().__init__()
        if padding is None:
            # Symmetric "same" padding is only possible for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        xavier_gain = nn.init.calculate_gain(w_init_gain, param=param)
        nn.init.xavier_uniform_(conv.weight, gain=xavier_gain)
        self.conv = conv

    def forward(self, signal):
        """Convolve ``signal`` (presumably ``(batch, in_channels, time)``)."""
        return self.conv(signal)
class CausualConv(nn.Module):
    """Causal 1-D convolution (class name keeps the original spelling).

    The wrapped ``nn.Conv1d`` pads both ends of the time axis by
    ``self.padding = 2 * padding`` frames, and ``forward`` drops the trailing
    ``self.padding`` frames of the result. When
    ``2 * padding == dilation * (kernel_size - 1)`` (e.g. ``kernel_size=3,
    padding=dilation``) and ``stride=1``, the output has the input's length and
    output frame ``t`` depends only on input frames ``<= t``.

    NOTE(review): the defaults (``kernel_size=1, padding=1``) do not satisfy
    that relation; they return ``time + 2`` frames.

    Args:
        padding: Half of the total causal padding. If ``None`` it is derived
            as ``dilation * (kernel_size - 1) / 2`` (odd ``kernel_size``
            required), which gives the exact causal padding.
        w_init_gain, param: Forwarded to ``torch.nn.init.calculate_gain`` for
            Xavier-uniform initialisation of the weights.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        # Assigned on both paths: the Conv1d below and forward() both read it.
        self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        """Convolve ``x`` and trim the right padding to keep it causal."""
        x = self.conv(x)
        # With zero padding there is nothing to trim, and ``x[..., :-0]``
        # would wrongly produce an empty tensor.
        if self.padding > 0:
            x = x[:, :, : -self.padding]
        return x
class CausualBlock(nn.Module):
    """Stack of residual causal-convolution sub-blocks.

    Sub-block ``i`` (``i`` in ``range(n_conv)``) is:
    CausualConv(kernel 3, dilation ``3**i``) -> activation -> BatchNorm1d ->
    Dropout -> CausualConv(kernel 3, dilation 1) -> activation -> Dropout,
    and its output is added to its input (residual connection). The channel
    count ``hidden_dim`` is preserved.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Apply every sub-block with a residual connection."""
        for block in self.blocks:
            # Out-of-place add: in eval mode (or with dropout_p == 0) Dropout
            # returns its input as-is, so block(x) can be the activation's own
            # output. Mutating it in place with `+=` breaks backward for
            # activations whose gradient uses their output (e.g. ReLU).
            x = x + block(x)
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        """Build one sub-block; padding=dilation keeps kernel-3 convs causal."""
        layers = [
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
class ConvBlock(nn.Module):
    """Stack of residual (non-causal) convolution sub-blocks.

    Sub-block ``i`` (``i`` in ``range(n_conv)``) is:
    ConvNorm(kernel 3, dilation ``3**i``, padding ``3**i``) -> activation ->
    GroupNorm(8 groups) -> Dropout -> ConvNorm(kernel 3, padding 1) ->
    activation -> Dropout, and its output is added to its input. The padding
    keeps the time length unchanged, so the residual shapes match.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        # nn.GroupNorm requires hidden_dim to be divisible by this group count.
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Apply every sub-block with a residual connection."""
        for block in self.blocks:
            # Out-of-place add: in eval mode (or with dropout_p == 0) Dropout
            # returns its input as-is, so block(x) can be the ReLU's own output.
            # ReLU's backward uses that output, so an in-place `+=` would
            # make backward fail.
            x = x + block(x)
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        """Build one sub-block of the stack."""
        layers = [
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
class LocationLayer(nn.Module):
    """Location features for attention.

    A bias-free convolution with 2 input channels (the stacked previous and
    cumulative attention weights) produces ``attention_n_filters`` feature
    maps. A bias-free linear layer (tanh-gain init) then projects each time
    step to ``attention_dim``.
    """

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super().__init__()
        half_width = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=half_width,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        """Map ``(B, 2, T)`` weights to ``(B, T, attention_dim)`` features.

        The time length is preserved for odd kernel sizes.
        """
        features = self.location_conv(attention_weights_cat).transpose(1, 2)
        return self.location_dense(features)
class Attention(nn.Module):
    """Additive attention with location features.

    Energies are ``v(tanh(query_layer(q) + location_layer(weights_cat) +
    processed_memory))``. Masked positions are set to ``-inf`` before a softmax
    over time, and the context vector is the weighted sum of ``memory``.

    NOTE(review): ``memory_layer`` is created here but not applied by this
    class. Callers presumably pass ``processed_memory = memory_layer(memory)``
    — confirm against the decoder.
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # -inf gives masked positions exactly zero weight after softmax.
        # NOTE(review): a row whose positions are all masked yields NaN weights.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: attention RNN hidden state (batch, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment energies (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )
        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs, (B, T_in, D) as required by torch.bmm
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask for padded data; True positions are excluded
        RETURNS
        -------
        attention_context: (B, D) weighted sum of memory
        attention_weights: (B, max_time) softmax over time
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # Out-of-place fill so autograd sees the masking (no `.data` hack).
            alignment = alignment.masked_fill(mask, self.score_mask_value)
        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
class ForwardAttentionV2(nn.Module):
    """Location-feature attention with a forward (monotonic) recursion.

    Content energies are computed as in ``Attention``. They are then combined
    with the previous step's log forward variable ``log_alpha``:

        log_alpha_new[t] = logsumexp(log_alpha[t], log_alpha[t-1]) + energy[t]

    with ``log_alpha[-1]`` treated as ``score_mask_value``. That is, attention
    may stay at position ``t`` or advance from ``t-1``. The attention weights
    are ``softmax(log_alpha_new)`` over time.
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # Large finite negative instead of -inf. NOTE(review): presumably this
        # keeps the logsumexp recursion finite where both terms are masked.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: attention RNN hidden state (batch, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment energies (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )
        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs, (B, T_in, D) as required by torch.bmm
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask for padded data; True positions are excluded
        log_alpha: log forward variable from the previous step, (B, max_time)
        RETURNS
        -------
        attention_context: (B, D) weighted sum of memory
        attention_weights: (B, max_time) softmax of log_alpha_new
        log_alpha_new: (B, max_time) updated log forward variable, presumably
            fed back as ``log_alpha`` on the next decoder step
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # Out-of-place fill so autograd sees the masking (no `.data` hack).
            log_energy = log_energy.masked_fill(mask, self.score_mask_value)

        max_time = log_energy.size(1)
        # "Stay" term: log_alpha[t]. "Advance" term: log_alpha[t-1], with the
        # first position padded by score_mask_value.
        stay = log_alpha[:, :max_time]
        advance = F.pad(
            log_alpha[:, : max_time - 1],
            (1, 0),
            mode="constant",
            value=self.score_mask_value,
        )
        biased = torch.logsumexp(torch.stack([stay, advance], dim=2), dim=2)
        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
    """Rotate a tensor along dim 3 by a random offset in ``[-n, n]``.

    A positive shift moves the first ``shift`` entries of dim 3 to the end; a
    negative shift moves the last ``-shift`` entries to the front. If
    ``|shift|`` is at least the size of dim 3, one slice is empty and the
    tensor comes back unrotated.

    Offsets come from a private ``random.Random(1)`` generator, so the shift
    sequence is reproducible and independent of the global ``random`` state.
    A draw happens on every call with ``move=None``, in train and eval mode.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Shift ``x`` (shaped ``(B, C, M, L)``) along ``L``.

        Args:
            x: input tensor; rotated along dim 3.
            move: explicit shift; when ``None`` one is drawn from
                ``self.random``.

        Returns:
            ``x`` itself when the shift is 0, otherwise a rotated copy.
        """
        shift = self.random.randint(-self.n, self.n) if move is None else move
        if shift == 0:
            return x
        head = x[:, :, :, :shift]
        tail = x[:, :, :, shift:]
        return torch.cat([tail, head], dim=3)
class PhaseShuffle1d(nn.Module):
    """Rotate a tensor along dim 2 by a random offset in ``[-n, n]``.

    A positive shift moves the first ``shift`` entries of dim 2 to the end; a
    negative shift moves the last ``-shift`` entries to the front. If
    ``|shift|`` is at least the size of dim 2, one slice is empty and the
    tensor comes back unrotated.

    Offsets come from a private ``random.Random(1)`` generator, so the shift
    sequence is reproducible and independent of the global ``random`` state.
    A draw happens on every call with ``move=None``, in train and eval mode.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Shift ``x`` along dim 2 (the time axis for a ``(B, C, L)`` input).

        Args:
            x: input tensor; rotated along dim 2.
            move: explicit shift; when ``None`` one is drawn from
                ``self.random``.

        Returns:
            ``x`` itself when the shift is 0, otherwise a rotated copy.
        """
        shift = self.random.randint(-self.n, self.n) if move is None else move
        if shift == 0:
            return x
        head = x[:, :, :shift]
        tail = x[:, :, shift:]
        return torch.cat([tail, head], dim=2)
class MFCC(nn.Module):
    """Project mel spectrograms onto an orthonormal DCT basis.

    Only the DCT step is performed here. The input is multiplied as-is by the
    ``(n_mels, n_mfcc)`` matrix from
    ``torchaudio.functional.create_dct(n_mfcc, n_mels, "ortho")``, with no log
    or other scaling. The matrix is registered as the ``dct_mat`` buffer, so
    it follows the module across devices and appears in ``state_dict``.
    """

    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        """Apply the DCT over the mel axis.

        Args:
            mel_specgram: tensor shaped ``(..., n_mels, time)``, e.g.
                ``(n_mels, time)``, ``(batch, n_mels, time)`` or
                ``(batch, channel, n_mels, time)``.

        Returns:
            Tensor shaped ``(..., n_mfcc, time)`` with the same leading dims.
        """
        # (..., n_mels, time).transpose -> (..., time, n_mels)
        # @ (n_mels, n_mfcc) -> (..., time, n_mfcc).transpose -> (..., n_mfcc, time)
        # Transposing the last two axes handles any number of leading dims.
        return torch.matmul(
            mel_specgram.transpose(-1, -2), self.dct_mat
        ).transpose(-1, -2)