coledie commited on
Commit
102339f
1 Parent(s): 9545d2c
Files changed (1) hide show
  1. vae.py +0 -63
vae.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  import torch.nn as nn
6
  import torchvision.datasets
7
  import torch.nn.functional as F
8
- from torchvision import transforms
9
 
10
 
11
  class Encoder(nn.Module):
@@ -65,65 +64,3 @@ class VAE(nn.Module):
65
  c = mu + sigma * torch.randn_like(sigma)
66
  xhat = self.decoder(c)
67
  return xhat, mu, sigma
68
-
69
-
70
- if __name__ == '__main__':
71
- N_EPOCHS = 50
72
- LEARNING_RATE = .001
73
-
74
- model = VAE().cuda()
75
- optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
76
- loss_fn = torch.nn.MSELoss()
77
-
78
- dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
79
- dataset_train, dataset_test = torch.utils.data.random_split(
80
- dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
81
- )
82
-
83
- model.train()
84
- dataloader = torch.utils.data.DataLoader(dataset_train,
85
- batch_size=512,
86
- shuffle=True,
87
- num_workers=0)
88
- i = 0
89
- for epoch in range(N_EPOCHS):
90
- total_loss = 0
91
- for x, label in dataloader:
92
- x = x.cuda()
93
- label = label.cuda()
94
- optimizer.zero_grad()
95
- xhat, mu, logvar = model(x)
96
-
97
- BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
98
- # https://arxiv.org/abs/1312.6114
99
- # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
100
- KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
101
- loss = BCE + KLD
102
- loss.backward()
103
- optimizer.step()
104
- total_loss += loss.item()
105
- print(f"{epoch}: {total_loss:.4f}")
106
-
107
- model.cpu()
108
- with open("vae.pt", "wb") as file:
109
- torch.save(model, file)
110
- model.eval()
111
- dataloader = torch.utils.data.DataLoader(dataset_test,
112
- batch_size=512,
113
- shuffle=True,
114
- num_workers=0)
115
- n_correct = 0
116
-
117
- COLS = 4
118
- ROWS = 4
119
- fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
120
- constrained_layout=True)
121
-
122
- dataloader_gen = iter(dataloader)
123
- x, label = next(dataloader_gen)
124
- xhat, mu, logvar = model(x)
125
- xhat = xhat.reshape((-1, 28, 28))
126
- for row in range(ROWS):
127
- for col in range(COLS):
128
- axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray")
129
- plt.show()
 
5
  import torch.nn as nn
6
  import torchvision.datasets
7
  import torch.nn.functional as F
 
8
 
9
 
10
  class Encoder(nn.Module):
 
64
  c = mu + sigma * torch.randn_like(sigma)
65
  xhat = self.decoder(c)
66
  return xhat, mu, sigma