Arnaudding001 committed on
Commit 15d7fed · 1 Parent(s): f13deb4

Create dualstylegan.py

Files changed (1)
  1. dualstylegan.py +203 -0
dualstylegan.py ADDED
@@ -0,0 +1,203 @@
import random
import torch
from torch import nn
from model.stylegan.model import ConvLayer, PixelNorm, EqualLinear, Generator

class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, fin, style_dim=512):
        super().__init__()

        self.norm = nn.InstanceNorm2d(fin, affine=False)
        self.style = nn.Linear(style_dim, fin * 2)

        self.style.bias.data[:fin] = 1
        self.style.bias.data[fin:] = 0

    def forward(self, input, style):
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)
        out = self.norm(input)
        out = gamma * out + beta
        return out
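
# Usage note (added for illustration; the shapes below are assumptions, not
# from the original commit): AdaptiveInstanceNorm normalizes each channel of
# an (N, C, H, W) feature map, then rescales it with gamma/beta predicted
# from an (N, 512) style vector, e.g.
#     adain = AdaptiveInstanceNorm(fin=64)
#     y = adain(torch.randn(2, 64, 32, 32), torch.randn(2, 512))  # (2, 64, 32, 32)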

# modulative residual blocks (ModRes)
class AdaResBlock(nn.Module):
    def __init__(self, fin, style_dim=512, dilation=1):  # modified
        super().__init__()

        self.conv = ConvLayer(fin, fin, 3, dilation=dilation)  # modified
        self.conv2 = ConvLayer(fin, fin, 3, dilation=dilation)  # modified
        self.norm = AdaptiveInstanceNorm(fin, style_dim)
        self.norm2 = AdaptiveInstanceNorm(fin, style_dim)

        # model initialization
        # the convolution filters are set to values close to 0 to produce negligible residual features
        self.conv[0].weight.data *= 0.01
        self.conv2[0].weight.data *= 0.01

    def forward(self, x, s, w=1):
        skip = x
        if w == 0:
            return skip
        out = self.conv(self.norm(x, s))
        out = self.conv2(self.norm2(out, s))
        out = out * w + skip
        return out
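
# Note (added): with the near-zero filter initialization above, a freshly
# constructed AdaResBlock starts close to the identity mapping, and w acts as
# a blend factor at inference time: w == 0 returns x unchanged, w == 1 applies
# the full residual update.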

class DualStyleGAN(nn.Module):
    def __init__(self, size, style_dim, n_mlp, channel_multiplier=2, twoRes=True, res_index=6):
        super().__init__()

        layers = [PixelNorm()]
        for i in range(n_mlp - 6):
            layers.append(EqualLinear(512, 512, lr_mul=0.01, activation="fused_lrelu"))
        # color transform blocks T_c
        self.style = nn.Sequential(*layers)
        # StyleGAN2
        self.generator = Generator(size, style_dim, n_mlp, channel_multiplier)
        # the extrinsic style path
        self.res = nn.ModuleList()
        self.res_index = res_index // 2 * 2
        self.res.append(AdaResBlock(self.generator.channels[2 ** 2]))  # for conv1
        for i in range(3, self.generator.log_size + 1):
            out_channel = self.generator.channels[2 ** i]
            if i < 3 + self.res_index // 2:
                # ModRes
                self.res.append(AdaResBlock(out_channel))
                self.res.append(AdaResBlock(out_channel))
            else:
                # structure transform block T_s
                self.res.append(EqualLinear(512, 512))
                # FC layer is initialized with identity matrices, meaning no changes to the input latent code
                self.res[-1].weight.data = torch.eye(512) * 512.0 ** 0.5 + torch.randn(512, 512) * 0.01
                self.res.append(EqualLinear(512, 512))
                self.res[-1].weight.data = torch.eye(512) * 512.0 ** 0.5 + torch.randn(512, 512) * 0.01
        self.res.append(EqualLinear(512, 512))  # for to_rgb7
        self.res[-1].weight.data = torch.eye(512) * 512.0 ** 0.5 + torch.randn(512, 512) * 0.01
        self.size = self.generator.size
        self.style_dim = self.generator.style_dim
        self.log_size = self.generator.log_size
        self.num_layers = self.generator.num_layers
        self.n_latent = self.generator.n_latent
        self.channels = self.generator.channels
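
    # Layout of self.res (added note, assuming the default res_index=6 and a
    # 1024-resolution generator): res[0] modulates conv1 at 4x4,
    # res[1..res_index] are ModRes blocks for the coarse/middle layers, and
    # the remaining entries are structure transform FC layers (initialized
    # near identity), the last one serving to_rgb7.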

    def forward(
        self,
        styles,                   # intrinsic style code
        exstyles,                 # extrinsic style code
        return_latents=False,
        return_feat=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        z_plus_latent=False,      # intrinsic style code is z+ or z
        use_res=True,             # whether to use the extrinsic style path
        fuse_index=18,            # layers > fuse_index do not use the extrinsic style path
        interp_weights=[1] * 18,  # weight vector for style combination of two paths
    ):

        if not input_is_latent:
            if not z_plus_latent:
                styles = [self.generator.style(s) for s in styles]
            else:
                styles = [self.generator.style(s.reshape(s.shape[0] * s.shape[1], s.shape[2])).reshape(s.shape) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.generator.num_layers
            else:
                noise = [
                    getattr(self.generator.noises, f"noise_{i}") for i in range(self.generator.num_layers)
                ]

        if truncation < 1:
            style_t = []
            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_t

        if len(styles) < 2:
            inject_index = self.generator.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.generator.n_latent - 1)
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
                latent2 = styles[1].unsqueeze(1).repeat(1, self.generator.n_latent - inject_index, 1)
                latent = torch.cat([latent, latent2], 1)
            else:
                latent = torch.cat([styles[0][:, 0:inject_index], styles[1][:, inject_index:]], 1)
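
        # Note (added): `latent` now holds the intrinsic W+ code, one 512-d
        # vector per generator layer; if two codes were given, style mixing
        # switched from the first to the second at inject_index above.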
148
+ if exstyles.ndim < 3:
149
+ resstyles = self.style(exstyles).unsqueeze(1).repeat(1, self.generator.n_latent, 1)
150
+ adastyles = exstyles.unsqueeze(1).repeat(1, self.generator.n_latent, 1)
151
+ else:
152
+ nB, nL, nD = exstyles.shape
153
+ resstyles = self.style(exstyles.reshape(nB*nL, nD)).reshape(nB, nL, nD)
154
+ adastyles = exstyles
155
+
156
+ out = self.generator.input(latent)
157
+ out = self.generator.conv1(out, latent[:, 0], noise=noise[0])
158
+ if use_res and fuse_index > 0:
159
+ out = self.res[0](out, resstyles[:, 0], interp_weights[0])
160
+
161
+ skip = self.generator.to_rgb1(out, latent[:, 1])
162
+ i = 1
163
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
164
+ self.generator.convs[::2], self.generator.convs[1::2], noise[1::2], noise[2::2], self.generator.to_rgbs):
165
+ if use_res and fuse_index >= i and i > self.res_index:
166
+ out = conv1(out, interp_weights[i] * self.res[i](adastyles[:, i]) +
167
+ (1-interp_weights[i]) * latent[:, i], noise=noise1)
168
+ else:
169
+ out = conv1(out, latent[:, i], noise=noise1)
170
+ if use_res and fuse_index >= i and i <= self.res_index:
171
+ out = self.res[i](out, resstyles[:, i], interp_weights[i])
172
+ if use_res and fuse_index >= (i+1) and i > self.res_index:
173
+ out = conv2(out, interp_weights[i+1] * self.res[i+1](adastyles[:, i+1]) +
174
+ (1-interp_weights[i+1]) * latent[:, i+1], noise=noise2)
175
+ else:
176
+ out = conv2(out, latent[:, i + 1], noise=noise2)
177
+ if use_res and fuse_index >= (i+1) and i <= self.res_index:
178
+ out = self.res[i+1](out, resstyles[:, i+1], interp_weights[i+1])
179
+ if use_res and fuse_index >= (i+2) and i >= self.res_index-1:
180
+ skip = to_rgb(out, interp_weights[i+2] * self.res[i+2](adastyles[:, i+2]) +
181
+ (1-interp_weights[i+2]) * latent[:, i + 2], skip)
182
+ else:
183
+ skip = to_rgb(out, latent[:, i + 2], skip)
184
+ i += 2
185
+ if i > self.res_index and return_feat:
186
+ return out, skip
187
+
188
+ image = skip
189
+
190
+ if return_latents:
191
+ return image, latent
192
+
193
+ else:
194
+ return image, None
195
+
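
    # Added note on the forward() contract: the method returns (image, latent)
    # when return_latents=True, returns (feature, skip_rgb) early once the
    # loop passes res_index when return_feat=True, and (image, None) otherwise.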

    def make_noise(self):
        return self.generator.make_noise()

    def mean_latent(self, n_latent):
        return self.generator.mean_latent(n_latent)

    def get_latent(self, input):
        return self.generator.style(input)
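

# Minimal smoke test (added sketch, not part of the original commit). It
# assumes the StyleGAN2 ops behind model.stylegan.model are available (they
# typically need the compiled CUDA extensions) and that the extrinsic code is
# supplied in w+ form with one 512-d vector per layer.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    g = DualStyleGAN(1024, 512, n_mlp=8).to(device)
    z = torch.randn(1, 512, device=device)                     # intrinsic z code
    exstyle = torch.randn(1, g.n_latent, 512, device=device)   # extrinsic w+ code (assumed shape)
    img, _ = g([z], exstyle, use_res=True, interp_weights=[0.75] * 18)
    print(img.shape)  # expected: torch.Size([1, 3, 1024, 1024])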