# Fashion_VAE / model.py — author: coledie, commit ea698d3.
# NOTE(review): the lines above/around here ("raw", "history blame", "4.64 kB")
# were Hugging Face page chrome captured by a copy-paste, not source code;
# they are preserved here as a comment so the module remains importable.
"""MNIST digit classificatin."""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets
import torch.nn.functional as F
from torchvision import transforms
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1x28x28 image to latent Gaussian parameters.

    Two conv+maxpool stages (28x28 -> 14x14 -> 7x7, 32 channels = 1568 flat
    features) followed by two linear heads producing ``mu`` and the second
    Gaussian parameter (treated as log-variance by the training loss).

    Args:
        image_dim: spatial shape of the input image, e.g. ``(28, 28)``.
        latent_dim: shape of the latent code, e.g. ``(14, 14)``.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # np.prod + int(): np.product was deprecated and removed in NumPy 2.0,
        # and nn.Linear expects a plain int, not a NumPy scalar.
        n_latent = int(np.prod(self.latent_dim))
        self.l_mu = nn.Linear(1568, n_latent)
        self.l_sigma = nn.Linear(1568, n_latent)

    def forward(self, x):
        """Return ``(mu, sigma)`` heads for batch ``x``.

        ``x`` is reshaped to ``(-1, 1, *image_dim)``, so flat or unbatched
        inputs are accepted. Each returned tensor is ``(batch, prod(latent_dim))``.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        features = self.cnn(x)
        return self.l_mu(features), self.l_sigma(features)
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent code back to a flat image in [0, 1].

    Two conv+maxpool stages over the 1x14x14 code (14x14 -> 7x7 -> 3x3,
    32 channels = 288 flat features), then a linear layer to the image size
    and a sigmoid so outputs are valid pixel intensities.

    Args:
        image_dim: spatial shape of the reconstructed image, e.g. ``(28, 28)``.
        latent_dim: shape of the latent code, e.g. ``(14, 14)``.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            # np.prod + int(): np.product was removed in NumPy 2.0, and
            # nn.Linear expects a plain int output size.
            nn.Linear(288, int(np.prod(self.image_dim))),
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent batch ``c`` to ``(batch, prod(image_dim))`` pixels in [0, 1]."""
        c = c.reshape((-1, 1, *self.latent_dim))
        return self.cnn(c)
class VAE(nn.Module):
    """Variational autoencoder: Encoder -> reparameterized sample -> Decoder.

    Args:
        image_dim: spatial shape of input/output images.
        latent_dim: shape of the latent code.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            (xhat, mu, logvar): reconstruction of shape
            ``(batch, prod(image_dim))`` and the Gaussian parameters.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        mu, logvar = self.encoder(x)
        # BUGFIX: the training loss treats the encoder's second head as
        # log-variance (KLD = -0.5 * mean(1 + logvar - mu^2 - exp(logvar))),
        # so the reparameterization trick must use std = exp(0.5 * logvar).
        # The original used the raw head output directly as std, which is
        # inconsistent with that loss (and can be negative).
        std = torch.exp(0.5 * logvar)
        c = mu + std * torch.randn_like(std)
        xhat = self.decoder(c)
        return xhat, mu, logvar
if __name__ == '__main__':
    N_EPOCHS = 100
    LEARNING_RATE = .001

    # Fall back to CPU when CUDA is unavailable instead of crashing at .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train on the union of FashionMNIST and MNIST (both stored under "MNIST/").
    dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_base_2 = torchvision.datasets.MNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_base = torch.utils.data.ConcatDataset([dataset_base, dataset_base_2])

    # Derive the test size from the remainder so the two splits always sum to
    # len(dataset_base); two independent int() truncations can drop samples.
    n_train = int(.8 * len(dataset_base))
    dataset_train, dataset_test = torch.utils.data.random_split(
        dataset_base, (n_train, len(dataset_base) - n_train)
    )

    model.train()
    dataloader = torch.utils.data.DataLoader(dataset_train,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    for epoch in range(N_EPOCHS):
        total_loss = 0
        for x, label in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            xhat, mu, logvar = model(x)
            # Reconstruction term: pixel-wise binary cross-entropy.
            BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
            # KL divergence to the unit-Gaussian prior (Kingma & Welling, 2013):
            # -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = BCE + KLD
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{epoch}: {total_loss:.4f}")

    # Persist the trained model on CPU so it loads on CUDA-less machines.
    model.cpu()
    with open("vae.pt", "wb") as file:
        torch.save(model, file)

    # Visualize a 4x4 grid of reconstructions from held-out data.
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset_test,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    COLS = 4
    ROWS = 4
    fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
                             constrained_layout=True)
    x, label = next(iter(dataloader))
    # no_grad: inference only — avoids building an autograd graph.
    with torch.no_grad():
        xhat, mu, logvar = model(x)
    xhat = xhat.reshape((-1, 28, 28))
    for row in range(ROWS):
        for col in range(COLS):
            axes[row, col].imshow(xhat[row * COLS + col].numpy(), cmap="gray")
    plt.show()