# Fashion_VAE / vae.py
"""MNIST digit classificatin."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Map an input image to the mean and scale of a diagonal Gaussian latent."""

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # Two conv + pool stages: (1, 28, 28) -> (16, 14, 14) -> (32, 7, 7),
        # then flatten to 32 * 7 * 7 = 1568 features.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # Separate linear heads produce the latent mean and standard deviation.
        self.l_mu = nn.Linear(1568, np.prod(self.latent_dim))
        self.l_sigma = nn.Linear(1568, np.prod(self.latent_dim))

    def forward(self, x):
        x = x.reshape((-1, 1, *self.image_dim))
        x = self.cnn(x)
        mu = self.l_mu(x)
        sigma = self.l_sigma(x)
        return mu, sigma
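

# Illustrative shape check (an assumption for clarity, not part of the original
# module): with the default dims used by VAE below, a (batch, 28, 28) input
# yields mu and sigma of shape (batch, 14 * 14) = (batch, 196).
#
#     enc = Encoder(image_dim=(28, 28), latent_dim=(14, 14))
#     mu, sigma = enc(torch.rand(8, 28, 28))  # both torch.Size([8, 196])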


class Decoder(nn.Module):
    """Map a latent code back to a flattened image in [0, 1]."""

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # The latent code is treated as a single-channel "image"; with the
        # default latent_dim=(14, 14), two conv + pool stages give
        # (16, 7, 7) -> (32, 3, 3), flattened to 32 * 3 * 3 = 288 features,
        # which a linear layer maps to the full image through a sigmoid.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            nn.Linear(288, np.prod(self.image_dim)),
            nn.Sigmoid(),
        )

    def forward(self, c):
        c = c.reshape((-1, 1, *self.latent_dim))
        x = self.cnn(c)
        return x


class VAE(nn.Module):
    """Convolutional variational autoencoder: encoder, reparameterization, decoder."""

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        x = x.reshape((-1, 1, *self.image_dim))
        mu, sigma = self.encoder(x)
        # Reparameterization trick: c = mu + sigma * eps with eps ~ N(0, 1).
        # Note the encoder's l_sigma output is used directly as the standard
        # deviation rather than as a log-variance.
        c = mu + sigma * torch.randn_like(sigma)
        xhat = self.decoder(c)
        return xhat, mu, sigma
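

# A minimal training-step sketch (an assumption, not part of the original file):
# it uses a random stand-in batch instead of a real Fashion-MNIST DataLoader and
# the standard VAE objective (reconstruction binary cross-entropy + KL divergence).
# Because the encoder emits sigma directly rather than a log-variance, the KL
# term clamps sigma^2 away from zero before taking its log.
if __name__ == "__main__":
    model = VAE()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    images = torch.rand(64, 28, 28)  # stand-in batch; replace with real data
    xhat, mu, sigma = model(images)

    # Reconstruction term: per-pixel BCE against the flattened input, summed over the batch.
    recon = F.binary_cross_entropy(
        xhat, images.reshape(-1, 28 * 28), reduction="sum"
    )
    # KL(N(mu, sigma^2) || N(0, 1)) for a diagonal Gaussian latent.
    var = sigma.pow(2).clamp_min(1e-8)
    kl = 0.5 * torch.sum(var + mu.pow(2) - 1.0 - torch.log(var))

    loss = recon + kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.2f}")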