Spaces:
Runtime error
Runtime error
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch.nn.utils import weight_norm | |
def WNConv1d(*args, **kwargs):
    """Create a weight-normalized ``nn.Conv1d``.

    Every positional and keyword argument is forwarded untouched to
    ``nn.Conv1d``; the freshly built layer is then wrapped with
    ``weight_norm`` (from ``torch.nn.utils``) and returned.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Create a weight-normalized ``nn.ConvTranspose1d``.

    Every positional and keyword argument is forwarded untouched to
    ``nn.ConvTranspose1d``; the freshly built layer is then wrapped with
    ``weight_norm`` (from ``torch.nn.utils``) and returned.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class LookupFreeQuantize(nn.Module):
    """Lookup-free (binary) quantizer.

    The input is optionally projected to ``codebook_dim`` channels with a
    1x1 ``WNConv1d``, squashed into (0, 1) with a sigmoid and rounded to
    {0, 1}. Channel ``i`` of the rounded tensor contributes ``2**i`` to the
    integer code, so codes are recovered arithmetically rather than by a
    codebook lookup.

    Args:
        input_dim: number of channels of the tensor passed to ``forward``.
        codebook_size: number of codes; must equal ``2 ** codebook_dim``.
        codebook_dim: number of binary channels used for each code.

    Raises:
        ValueError: if ``codebook_size != 2 ** codebook_dim``.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        # Explicit check: an ``assert`` would be stripped under ``python -O``.
        if 2**codebook_dim != codebook_size:
            raise ValueError(
                f"codebook_size ({codebook_size}) must equal "
                f"2 ** codebook_dim ({2**codebook_dim})"
            )
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            # Channel counts already match: no projection is needed.
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

    def _bit_shifts(self, device):
        """Return ``[0, 1, ..., codebook_dim - 1]`` shaped ``(1, codebook_dim, 1)``.

        Shared by ``forward`` (encode) and ``vq2emb`` (decode) so both agree
        on which channel holds which bit.
        """
        return torch.arange(self.codebook_dim, device=device).view(1, -1, 1)

    def forward(self, z):
        """Quantize ``z`` into binary codes.

        Args:
            z: input tensor fed to ``in_project``; presumably
                ``(B, input_dim, T)`` since the projection is a ``Conv1d``
                — TODO confirm with callers.

        Returns:
            Tuple ``(z_q, commit_loss, codebook_loss, indices, z_e)``:
            ``z_q`` is the rounded bits passed through ``out_project``, with a
            straight-through gradient to ``z_e``; ``commit_loss`` and
            ``codebook_loss`` are always zero tensors of shape ``(z.shape[0],)``;
            ``indices`` is the long tensor of codes (bits summed over dim 1);
            ``z_e`` is the sigmoid output before rounding.
        """
        z_e = torch.sigmoid(self.in_project(z))
        hard = torch.round(z_e)
        # Straight-through estimator: the forward value is the rounded bits,
        # while the backward pass treats rounding as the identity.
        z_q = self.out_project(z_e + (hard - z_e).detach())

        commit_loss = torch.zeros(z.shape[0], device=z.device)
        codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Bit in channel i is worth 2**i; summing over channels yields the code.
        indices = (hard.detach().long() << self._bit_shifts(hard.device)).sum(1)
        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        """Decode integer codes back into their binary embedding.

        Inverse of the index computation in ``forward``: bit ``i`` of each
        code becomes channel ``i`` of the result.

        Args:
            vq: integer code tensor; assumed ``(B, T)`` — TODO confirm.
            out_proj: when True, pass the bits through ``out_project``.

        Returns:
            ``(B, codebook_dim, T)`` tensor of 0/1 values in the default
            floating dtype, or ``out_project`` applied to it.
        """
        # Extract all bits in one broadcast op instead of a per-bit Python
        # loop; an arithmetic right shift equals floor division by 2**i.
        bits = (vq.long().unsqueeze(1) >> self._bit_shifts(vq.device)) & 1
        emb = bits.to(torch.get_default_dtype())  # (B, d, T)
        if out_proj:
            emb = self.out_project(emb)
        return emb