"""Convolutional variational autoencoder (VAE) for MNIST-sized images.

Defaults assume 28x28 single-channel images and a 14x14 latent grid.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def _conv_trunk():
    """Build the conv/pool/flatten layers shared by Encoder and Decoder.

    The layer order (``nn.Sequential`` indices 0-4) is kept identical to the
    original hand-written stacks so existing state_dicts still load.
    """
    return [
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
        nn.MaxPool2d(kernel_size=2),
        nn.Flatten(1, -1),
    ]


def _trunk_out_features(spatial_dim):
    """Return the flattened feature count ``_conv_trunk`` yields per sample.

    The 5x5 convolutions (stride 1, padding 2) preserve spatial size and each
    of the two ``MaxPool2d(kernel_size=2)`` layers floor-halves it; the last
    conv has 32 output channels.  E.g. (28, 28) -> 32*7*7 = 1568 and
    (14, 14) -> 32*3*3 = 288.
    """
    features = 32
    for d in spatial_dim:
        features *= int(d) // 2 // 2
    return features


class Encoder(nn.Module):
    """Map single-channel images to the parameters ``(mu, sigma)`` of the latent.

    Args:
        image_dim: spatial shape (height, width) of the input images.
        latent_dim: shape of the latent; outputs are flattened to
            ``prod(latent_dim)`` features.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(*_conv_trunk())
        # Derived from image_dim instead of hard-coding 1568 (valid only for 28x28).
        n_features = _trunk_out_features(image_dim)
        # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
        n_latent = int(np.prod(latent_dim))
        self.l_mu = nn.Linear(n_features, n_latent)
        self.l_sigma = nn.Linear(n_features, n_latent)

    def forward(self, x):
        """Encode ``x`` into ``(mu, sigma)``.

        ``x`` is reshaped to ``(-1, 1, *image_dim)``, so any tensor whose
        element count is a multiple of ``prod(image_dim)`` is accepted.
        Both outputs have shape ``(batch, prod(latent_dim))``.

        NOTE(review): ``sigma`` is an unconstrained linear output (it can be
        negative); confirm how the training loss interprets it.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        x = self.cnn(x)
        mu = self.l_mu(x)
        sigma = self.l_sigma(x)
        return mu, sigma


class Decoder(nn.Module):
    """Map a latent sample back to flattened image pixels in (0, 1).

    NOTE(review): this decoder runs the latent grid through the same
    downsampling conv/pool trunk as the encoder and then a single Linear layer
    up to image size; there is no transposed-conv / upsampling path.

    Args:
        image_dim: spatial shape (height, width) of the reconstructed image.
        latent_dim: spatial shape the latent vector is reshaped to.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            *_conv_trunk(),
            # Derived from latent_dim instead of hard-coding 288 (valid only for 14x14).
            nn.Linear(_trunk_out_features(latent_dim), int(np.prod(image_dim))),
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent ``c`` into shape ``(batch, prod(image_dim))``.

        ``c`` is reshaped to ``(-1, 1, *latent_dim)``; the final Sigmoid keeps
        every output value in (0, 1).
        """
        c = c.reshape((-1, 1, *self.latent_dim))
        x = self.cnn(c)
        return x


class VAE(nn.Module):
    """Variational autoencoder pairing an ``Encoder`` with a ``Decoder``.

    Args:
        image_dim: spatial shape of the single-channel images.
        latent_dim: shape of the latent grid handed to the decoder.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        """Encode, sample with the reparameterisation trick, and decode.

        Returns:
            ``(xhat, mu, sigma)``: ``xhat`` has shape ``(batch, prod(image_dim))``;
            ``mu`` and ``sigma`` have shape ``(batch, prod(latent_dim))``.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        mu, sigma = self.encoder(x)
        # c = mu + sigma * eps, eps ~ N(0, I): the sample stays differentiable
        # with respect to mu and sigma.
        c = mu + sigma * torch.randn_like(sigma)
        xhat = self.decoder(c)
        return xhat, mu, sigma