Spaces:
Runtime error
Runtime error
coledie
commited on
Commit
•
102339f
1
Parent(s):
9545d2c
Update.
Browse files
vae.py
CHANGED
@@ -5,7 +5,6 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
import torchvision.datasets
|
7 |
import torch.nn.functional as F
|
8 |
-
from torchvision import transforms
|
9 |
|
10 |
|
11 |
class Encoder(nn.Module):
|
@@ -65,65 +64,3 @@ class VAE(nn.Module):
|
|
65 |
c = mu + sigma * torch.randn_like(sigma)
|
66 |
xhat = self.decoder(c)
|
67 |
return xhat, mu, sigma
|
68 |
-
|
69 |
-
|
70 |
-
if __name__ == '__main__':
|
71 |
-
N_EPOCHS = 50
|
72 |
-
LEARNING_RATE = .001
|
73 |
-
|
74 |
-
model = VAE().cuda()
|
75 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
76 |
-
loss_fn = torch.nn.MSELoss()
|
77 |
-
|
78 |
-
dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
|
79 |
-
dataset_train, dataset_test = torch.utils.data.random_split(
|
80 |
-
dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
|
81 |
-
)
|
82 |
-
|
83 |
-
model.train()
|
84 |
-
dataloader = torch.utils.data.DataLoader(dataset_train,
|
85 |
-
batch_size=512,
|
86 |
-
shuffle=True,
|
87 |
-
num_workers=0)
|
88 |
-
i = 0
|
89 |
-
for epoch in range(N_EPOCHS):
|
90 |
-
total_loss = 0
|
91 |
-
for x, label in dataloader:
|
92 |
-
x = x.cuda()
|
93 |
-
label = label.cuda()
|
94 |
-
optimizer.zero_grad()
|
95 |
-
xhat, mu, logvar = model(x)
|
96 |
-
|
97 |
-
BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
|
98 |
-
# https://arxiv.org/abs/1312.6114
|
99 |
-
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
100 |
-
KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
|
101 |
-
loss = BCE + KLD
|
102 |
-
loss.backward()
|
103 |
-
optimizer.step()
|
104 |
-
total_loss += loss.item()
|
105 |
-
print(f"{epoch}: {total_loss:.4f}")
|
106 |
-
|
107 |
-
model.cpu()
|
108 |
-
with open("vae.pt", "wb") as file:
|
109 |
-
torch.save(model, file)
|
110 |
-
model.eval()
|
111 |
-
dataloader = torch.utils.data.DataLoader(dataset_test,
|
112 |
-
batch_size=512,
|
113 |
-
shuffle=True,
|
114 |
-
num_workers=0)
|
115 |
-
n_correct = 0
|
116 |
-
|
117 |
-
COLS = 4
|
118 |
-
ROWS = 4
|
119 |
-
fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
|
120 |
-
constrained_layout=True)
|
121 |
-
|
122 |
-
dataloader_gen = iter(dataloader)
|
123 |
-
x, label = next(dataloader_gen)
|
124 |
-
xhat, mu, logvar = model(x)
|
125 |
-
xhat = xhat.reshape((-1, 28, 28))
|
126 |
-
for row in range(ROWS):
|
127 |
-
for col in range(COLS):
|
128 |
-
axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray")
|
129 |
-
plt.show()
|
|
|
5 |
import torch.nn as nn
|
6 |
import torchvision.datasets
|
7 |
import torch.nn.functional as F
|
|
|
8 |
|
9 |
|
10 |
class Encoder(nn.Module):
|
|
|
64 |
c = mu + sigma * torch.randn_like(sigma)
|
65 |
xhat = self.decoder(c)
|
66 |
return xhat, mu, sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|