|
import random |
|
|
|
import torch.nn as nn |
|
from models.vq.encdec import Encoder, Decoder |
|
from models.vq.residual_vq import ResidualVQ |
|
|
|
class RVQVAE(nn.Module):
    """Residual vector-quantized autoencoder: encoder -> ResidualVQ -> decoder.

    This class wires together an ``Encoder``/``Decoder`` pair and a
    ``ResidualVQ`` quantizer, and swaps the last two axes of the input
    before encoding (see ``preprocess``).
    """

    def __init__(self,
                 args,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Build the encoder, decoder and residual quantizer.

        Args:
            args: Config object; must expose ``num_quantizers``,
                ``shared_codebook`` and ``quantize_dropout_prob``. The whole
                object is also forwarded to ``ResidualVQ`` as ``args``.
            input_width: Forwarded to ``Encoder`` and ``Decoder``
                (presumably the per-frame feature size of the input).
            nb_code: Forwarded to ``ResidualVQ`` as ``nb_code`` and stored
                as ``self.num_code``.
            code_dim: Forwarded to ``ResidualVQ`` as ``code_dim`` and stored
                as ``self.code_dim``; must equal ``output_emb_width``.
            output_emb_width, down_t, stride_t, width, depth,
            dilation_growth_rate, activation, norm: Forwarded unchanged to
                both ``Encoder`` and ``Decoder``.

        Raises:
            AssertionError: if ``output_emb_width != code_dim``.
        """
        super().__init__()
        # The quantizer operates directly on encoder output, so the encoder's
        # embedding width has to match the code dimension.
        assert output_emb_width == code_dim, (
            f"output_emb_width ({output_emb_width}) must equal "
            f"code_dim ({code_dim})"
        )
        self.code_dim = code_dim
        self.num_code = nb_code

        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0,
            'nb_code': nb_code,
            'code_dim': code_dim,
            'args': args,
        }
        self.quantizer = ResidualVQ(**rvqvae_config)

    def preprocess(self, x):
        """Swap axes 1 and 2 of ``x`` and cast it to float32.

        For an input assumed to be ``(N, T, C)``, this yields ``(N, C, T)``.
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap axes 1 and 2 of ``x`` (inverse of ``preprocess``'s permute)."""
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        """Encode ``x`` into residual-VQ code indices.

        Args:
            x: 3-D input, assumed ``(N, T, C)``.

        Returns:
            ``(code_idx, all_codes)`` exactly as returned by
            ``self.quantizer.quantize(..., return_latent=True)``.
        """
        x_encoder = self.encoder(self.preprocess(x))
        code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
        return code_idx, all_codes

    def forward(self, x, *, sample_codebook_temp=0.5):
        """Reconstruct ``x`` through the encoder, quantizer and decoder.

        Args:
            x: 3-D input, assumed ``(N, T, C)``.
            sample_codebook_temp: Forwarded to the quantizer call as
                ``sample_codebook_temp``. Default 0.5.

        Returns:
            ``(x_out, commit_loss, perplexity)``: the decoder output plus the
            commitment loss and perplexity reported by the quantizer.
        """
        x_encoder = self.encoder(self.preprocess(x))
        x_quantized, _code_idx, commit_loss, perplexity = self.quantizer(
            x_encoder, sample_codebook_temp=sample_codebook_temp)
        # postprocess() is deliberately not applied to the decoder output;
        # presumably Decoder already returns channel-last data (TODO confirm
        # against models.vq.encdec).
        x_out = self.decoder(x_quantized)
        return x_out, commit_loss, perplexity

    def forward_decoder(self, x):
        """Decode residual-VQ code indices.

        Args:
            x: Code indices in whatever form
                ``self.quantizer.get_codes_from_indices`` accepts.

        Returns:
            The decoder output for the summed residual code vectors.
        """
        x_d = self.quantizer.get_codes_from_indices(x)
        # Sum the per-quantizer codes stacked along dim 0 to rebuild the
        # quantized latent, then swap the last two axes (presumably into the
        # channel-first layout the decoder consumes).
        latent = x_d.sum(dim=0).permute(0, 2, 1)
        return self.decoder(latent)
|
|
|
class LengthEstimator(nn.Module):
    """MLP head mapping an input vector to ``output_size`` outputs.

    Architecture: three ``Linear -> LayerNorm -> LeakyReLU(0.2)`` stages of
    width ``hidden_size``, ``hidden_size // 2`` and ``hidden_size // 4``,
    with dropout after the first two stages, followed by a final ``Linear``
    projection to ``output_size``.
    """

    def __init__(self, input_size, output_size, hidden_size=512, dropout=0.2):
        """Build the MLP and initialise its weights.

        Args:
            input_size: Size of the last axis of the input.
            output_size: Size of the last axis of the output.
            hidden_size: Width of the first hidden stage; the next two stages
                use ``hidden_size // 2`` and ``hidden_size // 4``. Default 512.
            dropout: Dropout probability applied after the first and second
                stages. Default 0.2.
        """
        super().__init__()
        nd = hidden_size
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(dropout),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(dropout),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd // 4, output_size),
        )

        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        """Initialise a single submodule (called via ``nn.Module.apply``).

        * ``Linear`` / ``Embedding``: weights ~ N(0, 0.02); ``Linear`` bias 0.
        * ``LayerNorm``: weight 1, bias 0 (skipped when the affine parameters
          are absent, e.g. ``elementwise_affine=False``).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm without affine parameters stores None here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def forward(self, text_emb):
        """Apply the MLP to ``text_emb``.

        The last axis must equal ``input_size``. Any leading axes are
        preserved, and the last axis of the result has size ``output_size``.
        """
        return self.output(text_emb)