voice_conversion_demo / decoder_base.py
uzdzn's picture
Upload 7 files
2c7b92a verified
raw
history blame
No virus
8.95 kB
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
# Release assets of bshall/acoustic-model v0.1, keyed by the model name used in _acoustic().
_RELEASE_BASE = "https://github.com/bshall/acoustic-model/releases/download/v0.1"
URLS = {
    "hubert-discrete": f"{_RELEASE_BASE}/hubert-discrete-d49e1c77.pt",
    "hubert-soft": f"{_RELEASE_BASE}/hubert-soft-0321fd7e.pt",
}
class CustomLSTM(nn.Module):
    """Single-layer, batch-first LSTM written with explicit tensor operations.

    The fused weight matrices hold the four gates side by side in the order
    (input, forget, cell candidate, output), each ``hidden_sz`` columns wide.

    Args:
        input_sz: size of the last dimension of the input sequence.
        hidden_sz: size of the hidden and cell states.
    """

    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # torch.empty instead of the legacy torch.Tensor(...) constructor; the
        # values are overwritten by init_weights() right below.
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a batch-first sequence.

        Args:
            x: tensor of shape (batch, sequence, feature).
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                (batch, hidden_size). Zeros are used when omitted.

        Returns:
            ``(outputs, (h_t, c_t))`` where ``outputs`` has shape
            (batch, sequence, hidden_size) and ``(h_t, c_t)`` is the final state.
        """
        bs, seq_sz, _ = x.size()
        HS = self.hidden_size
        if init_states is None:
            # Allocate directly on x's device instead of on the CPU followed by a copy.
            h_t = torch.zeros(bs, HS, device=x.device)
            c_t = torch.zeros(bs, HS, device=x.device)
        else:
            h_t, c_t = init_states
        hidden_seq = []
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # all four gates from one pair of matrix multiplications
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t = torch.sigmoid(gates[:, :HS])            # input gate
            f_t = torch.sigmoid(gates[:, HS:HS * 2])      # forget gate
            g_t = torch.tanh(gates[:, HS * 2:HS * 3])     # cell candidate
            o_t = torch.sigmoid(gates[:, HS * 3:])        # output gate
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        if not hidden_seq:
            # torch.stack rejects an empty list; return an empty sequence instead.
            return h_t.new_zeros(bs, 0, HS), (h_t, c_t)
        # stacking on dim 1 yields a fresh contiguous (batch, sequence, hidden) tensor
        return torch.stack(hidden_seq, dim=1), (h_t, c_t)
class AcousticModel(nn.Module):
    """Speaker-conditioned acoustic model: a content ``Encoder`` followed by a ``Decoder``.

    The encoder output is concatenated, frame by frame, with a per-utterance
    speaker embedding before it reaches the decoder.

    Args:
        discrete: forwarded to ``Encoder``.
        upsample: forwarded to ``Encoder``.
        use_custom_lstm: forwarded to ``Decoder``.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm=False):
        super().__init__()
        self.encoder = Encoder(discrete, upsample)
        self.decoder = Decoder(use_custom_lstm=use_custom_lstm)

    @staticmethod
    def _condition(x: torch.Tensor, spk_embs: torch.Tensor) -> torch.Tensor:
        """Append the speaker embedding to every frame of ``x``.

        ``x`` is (batch, frames, dim). ``spk_embs`` is assumed to be
        (batch, spk_dim): it gains a frame axis and is expanded (without copying)
        across all frames before concatenation on the last dimension.
        NOTE(review): spk_dim is presumably 512 so that the result matches the
        decoder's 1024-wide conditioning input — confirm against the caller.
        """
        exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
        return torch.cat([x, exp_spk_embs], dim=-1)

    def forward(self, x: torch.Tensor, spk_embs, mels: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, condition on ``spk_embs`` and decode with the given ``mels``."""
        return self.decoder(self._condition(self.encoder(x), spk_embs), mels)

    def forward_test(self, x, spk_embs, mels):
        """Debug helper: print input and encoder-output shapes; returns None."""
        print('x shape', x.shape)
        print('se shape', spk_embs.shape)
        print('mels shape', mels.shape)
        x = self.encoder(x)
        print('x_enc shape', x.shape)
        return

    @torch.inference_mode()
    def generate(self, x: torch.Tensor, spk_embs) -> torch.Tensor:
        """Encode ``x``, condition on ``spk_embs`` and run the decoder's autoregressive generation."""
        return self.decoder.generate(self._condition(self.encoder(x), spk_embs))
class Encoder(nn.Module):
    """Content encoder: optional unit embedding, a PreNet, then a 1-D convolution stack.

    Args:
        discrete: when True, inputs are integer ids looked up in a 101-entry,
            256-wide embedding table; otherwise inputs are passed straight to
            the PreNet, whose first layer takes a 256-wide last dimension.
        upsample: when True, a stride-2 transposed convolution doubles the time
            axis; otherwise an identity layer occupies the same slot, so the
            ``convs`` indices are the same either way.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True):
        super().__init__()
        if discrete:
            self.embedding = nn.Embedding(100 + 1, 256)
        else:
            self.embedding = None
        self.prenet = PreNet(256, 256, 256)

        def conv_relu_norm(in_channels):
            # kernel 5, stride 1, padding 2 keeps the sequence length unchanged
            return [nn.Conv1d(in_channels, 512, 5, 1, 2), nn.ReLU(), nn.InstanceNorm1d(512)]

        layers = conv_relu_norm(256)
        layers.append(nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity())
        layers += conv_relu_norm(512)
        layers += conv_relu_norm(512)
        self.convs = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, time[, 256]) input to (batch, time', 512) features."""
        if self.embedding is not None:
            x = self.embedding(x)
        # Conv1d wants channels before time, so swap in and swap back out.
        hidden = self.prenet(x).transpose(1, 2)
        return self.convs(hidden).transpose(1, 2)
class Decoder(nn.Module):
    """Mel decoder: mel PreNet, three LSTM layers (two residual), linear projection.

    All inputs are batch-first: (batch, time, feature).

    Args:
        use_custom_lstm: when True use ``CustomLSTM`` cells instead of ``nn.LSTM``.
    """

    def __init__(self, use_custom_lstm=False):
        super().__init__()
        self.use_custom_lstm = use_custom_lstm
        self.prenet = PreNet(128, 256, 256)
        if use_custom_lstm:
            self.lstm1 = CustomLSTM(1024 + 256, 768)
            self.lstm2 = CustomLSTM(768, 768)
            self.lstm3 = CustomLSTM(768, 768)
        else:
            # batch_first=True: forward() and generate() build (batch, time, feature)
            # tensors, just like the batch-first CustomLSTM branch. Without it the
            # recurrence ran over the batch axis and generate() failed for batch > 1.
            # NOTE(review): weights trained with the old (non batch_first) layout
            # will behave differently after this fix.
            self.lstm1 = nn.LSTM(1024 + 256, 768, batch_first=True)
            self.lstm2 = nn.LSTM(768, 768, batch_first=True)
            self.lstm3 = nn.LSTM(768, 768, batch_first=True)
        self.proj = nn.Linear(768, 128, bias=False)

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Teacher-forced decoding.

        Args:
            x: conditioning features, (batch, time, 1024). The width follows from
                lstm1's input size minus the 256-wide PreNet output.
            mels: mel frames fed to the PreNet, (batch, time, 128). Presumably
                shifted by one frame for teacher forcing — confirm with the trainer.

        Returns:
            Predicted mel frames, (batch, time, 128).
        """
        mels = self.prenet(mels)
        x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
        res = x
        x, _ = self.lstm2(x)
        x = res + x
        res = x
        x, _ = self.lstm3(x)
        x = res + x
        return self.proj(x)

    @torch.inference_mode()
    def generate(self, xs: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce one mel frame per conditioning frame.

        Each predicted frame is fed back through the PreNet for the next step;
        the first step starts from an all-zero frame.

        Args:
            xs: conditioning features, (batch, time, 1024).

        Returns:
            Generated mel frames, (batch, time, 128).
        """
        batch = xs.size(0)
        # nn.LSTM takes (num_layers, batch, hidden) states even when batch_first=True;
        # CustomLSTM takes (batch, hidden).
        state_shape = (batch, 768) if self.use_custom_lstm else (1, batch, 768)
        s1, s2, s3 = (
            (torch.zeros(state_shape, device=xs.device), torch.zeros(state_shape, device=xs.device))
            for _ in range(3)
        )
        m = torch.zeros(batch, 128, device=xs.device)
        mel = []
        for x in torch.unbind(xs, dim=1):
            m = self.prenet(m)
            x = torch.cat((x, m), dim=1).unsqueeze(1)  # (batch, 1, 1024 + 256)
            x1, s1 = self.lstm1(x, s1)
            x2, s2 = self.lstm2(x1, s2)
            x = x1 + x2
            x3, s3 = self.lstm3(x, s3)
            x = x + x3
            m = self.proj(x).squeeze(1)
            mel.append(m)
        return torch.stack(mel, dim=1)
class PreNet(nn.Module):
    """Two-stage MLP; each stage is Linear -> ReLU -> Dropout.

    Args:
        input_size: width of the incoming last dimension.
        hidden_size: width after the first stage.
        output_size: width after the second stage.
        dropout: drop probability used by both Dropout layers.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        stages = []
        for fan_in, fan_out in ((input_size, hidden_size), (hidden_size, output_size)):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        # Sequential indices 0..5 are the same as the explicit six-layer listing.
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both stages to ``x`` along its last dimension."""
        return self.net(x)
def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Build an AcousticModel, optionally loading the checkpoint registered under ``name``.

    When ``pretrained`` is True the checkpoint at ``URLS[name]`` is downloaded.
    Its ``"acoustic-model"`` entry has any leading ``"module."`` prefix stripped in
    place and is loaded with the default (strict) ``load_state_dict``. The model
    is then switched to eval mode. Otherwise the freshly initialised model is
    returned as-is.

    NOTE(review): the URLs point at bshall/acoustic-model releases, while this
    file's Decoder expects a 1024 + 256 wide lstm1 input. The published weights
    may therefore not match these shapes — confirm before relying on
    pretrained=True.
    """
    model = AcousticModel(discrete, upsample)
    if not pretrained:
        return model
    state = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)["acoustic-model"]
    consume_prefix_in_state_dict_if_present(state, "module.")
    model.load_state_dict(state)
    model.eval()
    return model
def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""HuBERT-Discrete acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Uses a discrete-unit encoder (embedding lookup) with temporal upsampling.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    return _acoustic("hubert-discrete", pretrained=pretrained, progress=progress, discrete=True, upsample=True)
def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""HuBERT-Soft acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Uses a continuous-feature encoder (no embedding lookup) with temporal upsampling.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    return _acoustic("hubert-soft", pretrained=pretrained, progress=progress, discrete=False, upsample=True)