coledie commited on
Commit
ea698d3
·
1 Parent(s): 0b750fc

Add model.

Browse files
Files changed (1) hide show
  1. model.py +137 -0
model.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MNIST digit classificatin."""
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.datasets
7
+ import torch.nn.functional as F
8
+ from torchvision import transforms
9
+
10
+
11
+ class Encoder(nn.Module):
12
+ def __init__(self, image_dim, latent_dim):
13
+ super().__init__()
14
+ self.image_dim = image_dim
15
+ self.latent_dim = latent_dim
16
+ self.cnn = nn.Sequential(
17
+ nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
18
+ nn.MaxPool2d(kernel_size=2),
19
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
20
+ nn.MaxPool2d(kernel_size=2),
21
+ nn.Flatten(1, -1),
22
+ )
23
+ self.l_mu = nn.Linear(1568, np.product(self.latent_dim))
24
+ self.l_sigma = nn.Linear(1568, np.product(self.latent_dim))
25
+
26
+ def forward(self, x):
27
+ x = x.reshape((-1, 1, *self.image_dim))
28
+ x = self.cnn(x)
29
+ mu = self.l_mu(x)
30
+ sigma = self.l_sigma(x)
31
+ return mu, sigma
32
+
33
+
34
+ class Decoder(nn.Module):
35
+ def __init__(self, image_dim, latent_dim):
36
+ super().__init__()
37
+ self.image_dim = image_dim
38
+ self.latent_dim = latent_dim
39
+ self.cnn = nn.Sequential(
40
+ nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
41
+ nn.MaxPool2d(kernel_size=2),
42
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
43
+ nn.MaxPool2d(kernel_size=2),
44
+ nn.Flatten(1, -1),
45
+ nn.Linear(288, np.product(self.image_dim)),
46
+ nn.Sigmoid(),
47
+ )
48
+
49
+ def forward(self, c):
50
+ c = c.reshape((-1, 1, *self.latent_dim))
51
+ x = self.cnn(c)
52
+ return x
53
+
54
+
55
+ class VAE(nn.Module):
56
+ def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
57
+ super().__init__()
58
+ self.image_dim = image_dim
59
+ self.encoder = Encoder(image_dim, latent_dim)
60
+ self.decoder = Decoder(image_dim, latent_dim)
61
+
62
+ def forward(self, x):
63
+ x = x.reshape((-1, 1, *self.image_dim))
64
+ mu, sigma = self.encoder(x)
65
+ c = mu + sigma * torch.randn_like(sigma)
66
+ xhat = self.decoder(c)
67
+ return xhat, mu, sigma
68
+
69
+
70
+ if __name__ == '__main__':
71
+ N_EPOCHS = 100
72
+ LEARNING_RATE = .001
73
+
74
+ model = VAE().cuda()
75
+ optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
76
+ loss_fn = torch.nn.MSELoss()
77
+
78
+ dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
79
+
80
+ dataset_base_2 = torchvision.datasets.MNIST("MNIST", download=True, transform=transforms.ToTensor())
81
+ dataset_base = torch.utils.data.ConcatDataset([dataset_base, dataset_base_2])
82
+
83
+ dataset_train, dataset_test = torch.utils.data.random_split(
84
+ dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
85
+ )
86
+
87
+ model.train()
88
+ dataloader = torch.utils.data.DataLoader(dataset_train,
89
+ batch_size=512,
90
+ shuffle=True,
91
+ num_workers=0)
92
+ i = 0
93
+ for epoch in range(N_EPOCHS):
94
+ total_loss = 0
95
+ for x, label in dataloader:
96
+ #for j in range(512):
97
+ # plt.imsave(f"{i}-{label[j]}.jpg", np.stack([x[j].reshape((28, 28)).detach().numpy()] * 3, -1), cmap='gray')
98
+ # i += 1
99
+ #exit()
100
+ x = x.cuda()
101
+ label = label.cuda()
102
+ optimizer.zero_grad()
103
+ xhat, mu, logvar = model(x)
104
+
105
+ BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
106
+ # https://arxiv.org/abs/1312.6114
107
+ # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
108
+ KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
109
+ loss = BCE + KLD
110
+ loss.backward()
111
+ optimizer.step()
112
+ total_loss += loss.item()
113
+ print(f"{epoch}: {total_loss:.4f}")
114
+
115
+ model.cpu()
116
+ with open("vae.pt", "wb") as file:
117
+ torch.save(model, file)
118
+ model.eval()
119
+ dataloader = torch.utils.data.DataLoader(dataset_test,
120
+ batch_size=512,
121
+ shuffle=True,
122
+ num_workers=0)
123
+ n_correct = 0
124
+
125
+ COLS = 4
126
+ ROWS = 4
127
+ fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
128
+ constrained_layout=True)
129
+
130
+ dataloader_gen = iter(dataloader)
131
+ x, label = next(dataloader_gen)
132
+ xhat, mu, logvar = model(x)
133
+ xhat = xhat.reshape((-1, 28, 28))
134
+ for row in range(ROWS):
135
+ for col in range(COLS):
136
+ axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray")
137
+ plt.show()