Spaces:
Running
Running
coledie
commited on
Commit
·
ea698d3
1
Parent(s):
0b750fc
Add model.
Browse files
model.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MNIST digit classificatin."""
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision.datasets
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
|
11 |
+
class Encoder(nn.Module):
|
12 |
+
def __init__(self, image_dim, latent_dim):
|
13 |
+
super().__init__()
|
14 |
+
self.image_dim = image_dim
|
15 |
+
self.latent_dim = latent_dim
|
16 |
+
self.cnn = nn.Sequential(
|
17 |
+
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
|
18 |
+
nn.MaxPool2d(kernel_size=2),
|
19 |
+
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
|
20 |
+
nn.MaxPool2d(kernel_size=2),
|
21 |
+
nn.Flatten(1, -1),
|
22 |
+
)
|
23 |
+
self.l_mu = nn.Linear(1568, np.product(self.latent_dim))
|
24 |
+
self.l_sigma = nn.Linear(1568, np.product(self.latent_dim))
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = x.reshape((-1, 1, *self.image_dim))
|
28 |
+
x = self.cnn(x)
|
29 |
+
mu = self.l_mu(x)
|
30 |
+
sigma = self.l_sigma(x)
|
31 |
+
return mu, sigma
|
32 |
+
|
33 |
+
|
34 |
+
class Decoder(nn.Module):
|
35 |
+
def __init__(self, image_dim, latent_dim):
|
36 |
+
super().__init__()
|
37 |
+
self.image_dim = image_dim
|
38 |
+
self.latent_dim = latent_dim
|
39 |
+
self.cnn = nn.Sequential(
|
40 |
+
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
|
41 |
+
nn.MaxPool2d(kernel_size=2),
|
42 |
+
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
|
43 |
+
nn.MaxPool2d(kernel_size=2),
|
44 |
+
nn.Flatten(1, -1),
|
45 |
+
nn.Linear(288, np.product(self.image_dim)),
|
46 |
+
nn.Sigmoid(),
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, c):
|
50 |
+
c = c.reshape((-1, 1, *self.latent_dim))
|
51 |
+
x = self.cnn(c)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class VAE(nn.Module):
|
56 |
+
def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
|
57 |
+
super().__init__()
|
58 |
+
self.image_dim = image_dim
|
59 |
+
self.encoder = Encoder(image_dim, latent_dim)
|
60 |
+
self.decoder = Decoder(image_dim, latent_dim)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
x = x.reshape((-1, 1, *self.image_dim))
|
64 |
+
mu, sigma = self.encoder(x)
|
65 |
+
c = mu + sigma * torch.randn_like(sigma)
|
66 |
+
xhat = self.decoder(c)
|
67 |
+
return xhat, mu, sigma
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == '__main__':
|
71 |
+
N_EPOCHS = 100
|
72 |
+
LEARNING_RATE = .001
|
73 |
+
|
74 |
+
model = VAE().cuda()
|
75 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
76 |
+
loss_fn = torch.nn.MSELoss()
|
77 |
+
|
78 |
+
dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
|
79 |
+
|
80 |
+
dataset_base_2 = torchvision.datasets.MNIST("MNIST", download=True, transform=transforms.ToTensor())
|
81 |
+
dataset_base = torch.utils.data.ConcatDataset([dataset_base, dataset_base_2])
|
82 |
+
|
83 |
+
dataset_train, dataset_test = torch.utils.data.random_split(
|
84 |
+
dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
|
85 |
+
)
|
86 |
+
|
87 |
+
model.train()
|
88 |
+
dataloader = torch.utils.data.DataLoader(dataset_train,
|
89 |
+
batch_size=512,
|
90 |
+
shuffle=True,
|
91 |
+
num_workers=0)
|
92 |
+
i = 0
|
93 |
+
for epoch in range(N_EPOCHS):
|
94 |
+
total_loss = 0
|
95 |
+
for x, label in dataloader:
|
96 |
+
#for j in range(512):
|
97 |
+
# plt.imsave(f"{i}-{label[j]}.jpg", np.stack([x[j].reshape((28, 28)).detach().numpy()] * 3, -1), cmap='gray')
|
98 |
+
# i += 1
|
99 |
+
#exit()
|
100 |
+
x = x.cuda()
|
101 |
+
label = label.cuda()
|
102 |
+
optimizer.zero_grad()
|
103 |
+
xhat, mu, logvar = model(x)
|
104 |
+
|
105 |
+
BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
|
106 |
+
# https://arxiv.org/abs/1312.6114
|
107 |
+
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
108 |
+
KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
|
109 |
+
loss = BCE + KLD
|
110 |
+
loss.backward()
|
111 |
+
optimizer.step()
|
112 |
+
total_loss += loss.item()
|
113 |
+
print(f"{epoch}: {total_loss:.4f}")
|
114 |
+
|
115 |
+
model.cpu()
|
116 |
+
with open("vae.pt", "wb") as file:
|
117 |
+
torch.save(model, file)
|
118 |
+
model.eval()
|
119 |
+
dataloader = torch.utils.data.DataLoader(dataset_test,
|
120 |
+
batch_size=512,
|
121 |
+
shuffle=True,
|
122 |
+
num_workers=0)
|
123 |
+
n_correct = 0
|
124 |
+
|
125 |
+
COLS = 4
|
126 |
+
ROWS = 4
|
127 |
+
fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
|
128 |
+
constrained_layout=True)
|
129 |
+
|
130 |
+
dataloader_gen = iter(dataloader)
|
131 |
+
x, label = next(dataloader_gen)
|
132 |
+
xhat, mu, logvar = model(x)
|
133 |
+
xhat = xhat.reshape((-1, 28, 28))
|
134 |
+
for row in range(ROWS):
|
135 |
+
for col in range(COLS):
|
136 |
+
axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray")
|
137 |
+
plt.show()
|