# Initial commit (de6e35f) — NorHsangPha
import json
import torch
import torch.nn as nn
from Preprocessing.Codec.env import AttrDict
from Preprocessing.Codec.models import Encoder
from Preprocessing.Codec.models import Generator
from Preprocessing.Codec.models import Quantizer
class VQVAE(nn.Module):
    """Vector-quantized autoencoder wrapper around pretrained codec modules.

    Builds a ``Quantizer`` and ``Generator`` (and optionally an ``Encoder``)
    from a JSON hyperparameter config and restores their weights from a
    torch checkpoint.

    Args:
        config_path: Path to a JSON file of model hyperparameters.
        ckpt_path: Path to a checkpoint with ``'generator'``, ``'quantizer'``
            and (when ``with_encoder=True``) ``'encoder'`` state dicts.
        with_encoder: When True, also build and load the Encoder so that
            :meth:`encode` is usable; otherwise only decoding works.
    """

    def __init__(self, config_path, ckpt_path, with_encoder=False):
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary Python objects —
        # only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        with open(config_path) as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            # The encoder is only needed for encode(); skip it for
            # decode-only use to avoid loading unnecessary weights.
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        """Decode codebook indices into the generator's output.

        Args:
            x: Codebook indices of shape (B, T, Nq) — presumably integer
               codes consumed by ``Quantizer.embed``; confirm against the
               Quantizer implementation.

        Returns:
            The generator's output on the embedded codes.
        """
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        """Encode an input signal into stacked codebook indices.

        Requires the model to have been constructed with
        ``with_encoder=True``; otherwise ``self.encoder`` does not exist.

        Args:
            x: Input tensor of shape (B, T) or (B, T, 1); a trailing
               singleton dimension is squeezed away.

        Returns:
            Codes stacked along the last dimension, shape (B, T', Nq).
        """
        batch_size = x.size(0)
        # Accept (B, T, 1) input by dropping the trailing channel dim.
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        c = self.encoder(x.unsqueeze(1))
        q, loss_q, c = self.quantizer(c)
        # Flatten each codebook's indices to (B, T') before stacking.
        c = [code.reshape(batch_size, -1) for code in c]
        return torch.stack(c, -1)