Spaces:
Runtime error
Runtime error
"""MNIST digit classification."""
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel images to ``(mu, sigma)``.

    Two stages of 5x5 convolution (padding=2, so spatial size is preserved)
    followed by 2x2 max-pooling shrink each spatial dimension to
    ``dim // 4``. The flattened features feed two linear heads.

    Args:
        image_dim: ``(height, width)`` of the input images.
        latent_dim: Shape of the latent code. Each head outputs
            ``prod(latent_dim)`` features.

    Raises:
        ValueError: If either spatial dimension of ``image_dim`` is below 4,
            which would leave no features after the two pooling stages.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        height, width = image_dim
        if height < 4 or width < 4:
            raise ValueError(
                f"image_dim must be at least (4, 4), got {tuple(image_dim)}"
            )
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # 32 channels over a (H // 4) x (W // 4) map after the two pools.
        # For 28x28 inputs this is 32 * 7 * 7 = 1568.
        cnn_features = 32 * (height // 4) * (width // 4)
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        n_latent = int(np.prod(self.latent_dim))
        self.l_mu = nn.Linear(cnn_features, n_latent)
        self.l_sigma = nn.Linear(cnn_features, n_latent)

    def forward(self, x):
        """Encode ``x`` into ``(mu, sigma)``.

        ``x`` is reshaped to ``(-1, 1, *image_dim)``. Any tensor whose element
        count is a multiple of ``prod(image_dim)`` is therefore accepted.

        Returns:
            Two tensors of shape ``(batch, prod(latent_dim))``. ``sigma`` is
            the raw output of a linear layer, so it is not constrained to be
            positive.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        x = self.cnn(x)
        mu = self.l_mu(x)
        sigma = self.l_sigma(x)
        return mu, sigma
class Decoder(nn.Module):
    """Convolutional decoder mapping latent codes to flattened images.

    The latent vector is reshaped to a one-channel ``latent_dim`` map and
    passed through two stages of 5x5 convolution (padding=2) + 2x2
    max-pooling. A final linear layer + sigmoid then projects it to
    ``prod(image_dim)`` values in ``(0, 1)``.

    Args:
        image_dim: ``(height, width)`` of the images to reconstruct.
        latent_dim: ``(height, width)`` the latent vector is reshaped to
            before the convolutions.

    Raises:
        ValueError: If either dimension of ``latent_dim`` is below 4, which
            would leave no features after the two pooling stages.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        height, width = latent_dim
        if height < 4 or width < 4:
            raise ValueError(
                f"latent_dim must be at least (4, 4), got {tuple(latent_dim)}"
            )
        # 32 channels over a (H // 4) x (W // 4) map after the two pools.
        # For a 14x14 latent this is 32 * 3 * 3 = 288.
        cnn_features = 32 * (height // 4) * (width // 4)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            # np.product was removed in NumPy 2.0; np.prod is the supported name.
            nn.Linear(cnn_features, int(np.prod(self.image_dim))),
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent codes ``c`` into flattened images.

        ``c`` is reshaped to ``(-1, 1, *latent_dim)``.

        Returns:
            A tensor of shape ``(batch, prod(image_dim))`` with values in
            ``(0, 1)``. It is not reshaped back to ``image_dim``.
        """
        c = c.reshape((-1, 1, *self.latent_dim))
        x = self.cnn(c)
        return x
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        image_dim: ``(height, width)`` of the input images.
        latent_dim: Shape of the latent code shared by encoder and decoder.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim=image_dim, latent_dim=latent_dim)
        self.decoder = Decoder(image_dim=image_dim, latent_dim=latent_dim)

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            ``(xhat, mu, sigma)``: the decoder output for the sampled code,
            plus the encoder's ``mu`` and ``sigma``.
        """
        images = x.reshape((-1, 1, *self.image_dim))
        mu, sigma = self.encoder(images)
        # Reparameterization trick: shift/scale standard-normal noise by the
        # encoder outputs so the sample stays differentiable in mu and sigma.
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        reconstruction = self.decoder(latent)
        return reconstruction, mu, sigma