GGmorello commited on
Commit
793ec18
1 Parent(s): ebcb0f8

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +66 -0
model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class Generator(nn.Module):
4
+ def __init__(self, z_dim=100, img_channels=3):
5
+ super(Generator, self).__init__()
6
+ self.gen = nn.Sequential(
7
+ # input is Z, going into a convolution
8
+ nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
9
+ nn.BatchNorm2d(512),
10
+ nn.ReLU(True),
11
+ # state size. 512 x 4 x 4
12
+ nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
13
+ nn.BatchNorm2d(256),
14
+ nn.ReLU(True),
15
+ # state size. 256 x 8 x 8
16
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
17
+ nn.BatchNorm2d(128),
18
+ nn.ReLU(True),
19
+ # state size. 128 x 16 x 16
20
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
21
+ nn.BatchNorm2d(64),
22
+ nn.ReLU(True),
23
+ # state size. 64 x 32 x 32
24
+ nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
25
+ nn.Tanh()
26
+ # state size. img_channels x 64 x 64
27
+ )
28
+
29
+ def forward(self, input):
30
+ return self.gen(input)
31
+
32
+ class Discriminator(nn.Module):
33
+ def __init__(self, img_channels=3):
34
+ super(Discriminator, self).__init__()
35
+ self.disc = nn.Sequential(
36
+ # input is img_channels x 64 x 64
37
+ nn.Conv2d(img_channels, 64, 4, 2, 1, bias=False),
38
+ nn.LeakyReLU(0.2, inplace=True),
39
+ # state size. 64 x 32 x 32
40
+ nn.Conv2d(64, 128, 4, 2, 1, bias=False),
41
+ nn.BatchNorm2d(128),
42
+ nn.LeakyReLU(0.2, inplace=True),
43
+ # state size. 128 x 16 x 16
44
+ nn.Conv2d(128, 256, 4, 2, 1, bias=False),
45
+ nn.BatchNorm2d(256),
46
+ nn.LeakyReLU(0.2, inplace=True),
47
+ # state size. 256 x 8 x 8
48
+ nn.Conv2d(256, 512, 4, 2, 1, bias=False),
49
+ nn.BatchNorm2d(512),
50
+ nn.LeakyReLU(0.2, inplace=True),
51
+ # state size. 512 x 4 x 4
52
+ nn.Conv2d(512, 1, 4, 1, 0, bias=False),
53
+ nn.Sigmoid()
54
+ )
55
+
56
+ def forward(self, input):
57
+ return self.disc(input).view(-1, 1).squeeze(1)
58
+
59
+ batch_size = 32
60
+ latent_vector_size = 100
61
+
62
+ generator = Generator()
63
+ discriminator = Discriminator()
64
+
65
+ generator.load_state_dict(torch.load('netG.pth', map_location=torch.device('cpu') ))
66
+ discriminator.load_state_dict(torch.load('netD.pth', map_location=torch.device('cpu') ))