# Spaces:
# Paused
# Paused
import torch | |
from torch import nn | |
from TTS.encoder.models.base_encoder import BaseEncoder | |
class LSTMWithProjection(nn.Module):
    """Single-layer, batch-first LSTM followed by a bias-free linear projection.

    The projection is applied to the LSTM output at every time step, so the
    sequence length is preserved and only the feature size changes.

    Args:
        input_size (int): Feature size of each input step.
        hidden_size (int): Hidden size of the LSTM.
        proj_size (int): Feature size produced by the linear projection.
    """

    def __init__(self, input_size, hidden_size, proj_size):
        super().__init__()
        # Keep the sizes around as plain attributes for introspection.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.proj_size = proj_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size, out_features=proj_size, bias=False)

    def forward(self, x):
        """Run the LSTM over ``x`` and project every time step.

        Args:
            x (Tensor): Batch-first input sequence whose last dimension is ``input_size``.

        Returns:
            Tensor: Projected per-step outputs whose last dimension is ``proj_size``.
        """
        # Re-compact the LSTM weights before running the recurrent kernel.
        self.lstm.flatten_parameters()
        step_outputs, _ = self.lstm(x)
        return self.linear(step_outputs)
class LSTMWithoutProjection(nn.Module):
    """Multi-layer, batch-first LSTM summarised by its final hidden state.

    Unlike ``LSTMWithProjection`` there is no projection between LSTM layers.
    Despite the name, a single ``Linear`` + ``ReLU`` head is still applied, but
    only to the last layer's final hidden state, so the output has no time axis.

    Args:
        input_dim (int): Feature size of each input step.
        lstm_dim (int): Hidden size of every LSTM layer.
        proj_dim (int): Output size of the linear head.
        num_lstm_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode a sequence into one vector.

        Args:
            x (Tensor): Batch-first input sequence whose last dimension is ``input_dim``.

        Returns:
            Tensor: ``relu(linear(h))`` where ``h`` is the final hidden state of the
            last LSTM layer; the last dimension is ``proj_dim``.
        """
        # Same call the sibling LSTMWithProjection makes: keeps the weights in one
        # contiguous chunk and avoids the non-contiguous-weights warning.
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(x)
        # hidden is stacked per layer; index -1 selects the top layer.
        return self.relu(self.linear(hidden[-1]))
class LSTMSpeakerEncoder(BaseEncoder):
    """LSTM encoder that turns spectrogram frames (or a raw waveform) into one embedding per input.

    Args:
        input_dim (int): Number of spectrogram channels fed to the first LSTM layer
            (also the channel count used for instance normalisation).
        proj_dim (int): Size of the output embedding.
        lstm_dim (int): Hidden size of the LSTM layers.
        num_lstm_layers (int): Number of LSTM layers.
        use_lstm_with_projection (bool): If ``True``, stack ``LSTMWithProjection``
            layers and take the last time step; otherwise use a single
            ``LSTMWithoutProjection`` module.
        use_torch_spec (bool): If ``True``, build a spectrogram module from
            ``audio_config`` and apply it to the input inside ``forward``.
        audio_config: Configuration passed to the inherited
            ``get_torch_mel_spectrogram_class`` helper when ``use_torch_spec`` is set.
    """

    def __init__(
        self,
        input_dim,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_lstm_with_projection=True,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()
        self.use_lstm_with_projection = use_lstm_with_projection
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        self.proj_dim = proj_dim

        # Choose the LSTM stack.
        if use_lstm_with_projection:
            # The first layer consumes the input features; every following layer
            # consumes the previous layer's projection.
            layers = [LSTMWithProjection(input_dim, lstm_dim, proj_dim)]
            for _ in range(num_lstm_layers - 1):
                layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
            self.layers = nn.Sequential(*layers)
        else:
            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        if self.use_torch_spec:
            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
        else:
            self.torch_spec = None

        self._init_layers()

    def _init_layers(self):
        """Zero every bias and Xavier-normal initialise every weight in ``self.layers``."""
        for name, param in self.layers.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.xavier_normal_(param)

    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform,
                `use_torch_spec` must be `True` to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Returns:
            Tensor: One embedding per batch item, with last dimension ``proj_dim``.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        # Feature extraction and normalisation run without gradient tracking and
        # with CUDA autocast disabled; only the LSTM stack below is differentiated.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    # Out-of-place squeeze: the caller's tensor must keep its shape.
                    x = x.squeeze(1)
                    x = self.torch_spec(x)
                # Normalise over time, then move time to dim 1 for the batch-first LSTMs.
                x = self.instancenorm(x).transpose(1, 2)
        d = self.layers(x)
        if self.use_lstm_with_projection:
            # The projection stack returns every time step; keep only the last one.
            d = d[:, -1]
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d