Arnaudding001 committed
Commit e628a3b · 1 Parent(s): cd6cde5

Create vtoonify.py

Files changed (1):
  1. vtoonify.py (+286, -0)

vtoonify.py ADDED
@@ -0,0 +1,286 @@
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from model.stylegan.model import ConvLayer, EqualLinear, Generator, ResBlock
from model.dualstylegan import AdaptiveInstanceNorm, AdaResBlock, DualStyleGAN

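# This file defines the VToonify framework: a conditional StyleGAN
# discriminator (ConditionalDiscriminator), the encoder residual block
# (VToonifyResBlock), the encoder-to-generator fusion module (Fusion), and
# the full translation network (VToonify).
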
# IC-GAN: StyleGAN discriminator
class ConditionalDiscriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], use_condition=False, style_num=None):
        super().__init__()

        # channel width per feature-map resolution
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        # downsampling residual blocks from the input resolution to 4x4
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1
        self.use_condition = use_condition

        if self.use_condition:
            self.condition_dim = 128
            # map the scalar style degree to a 64-dimensional vector
            self.label_mapper = nn.Sequential(
                nn.Linear(1, 64),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(64, 64),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(64, self.condition_dim // 2),
            )
            # map the style code index to a 64-dimensional vector
            self.style_mapper = nn.Embedding(style_num, self.condition_dim - self.condition_dim // 2)
        else:
            self.condition_dim = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], self.condition_dim),
        )

    def forward(self, input, degree_label=None, style_ind=None):
        out = self.convs(input)

        # minibatch standard deviation: append the per-group feature stddev
        # as an extra channel so the discriminator can sense sample diversity
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        out = out.view(batch, -1)

        if self.use_condition:
            h = self.final_linear(out)
            # projection conditioning: score = <features, condition embedding>
            condition = torch.cat((self.label_mapper(degree_label), self.style_mapper(style_ind)), dim=1)
            out = (h * condition).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.condition_dim))
        else:
            out = self.final_linear(out)

        return out
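
# The conditional branch above implements projection-style conditioning: the
# realness score is the inner product between the pooled image features and a
# learned embedding of (style degree, style index), scaled by
# 1/sqrt(condition_dim). A minimal sketch of that computation on dummy
# tensors (shapes chosen here for illustration only):
#
#   h = torch.randn(4, 128)        # pooled image features
#   c = torch.randn(4, 128)        # condition embedding
#   score = (h * c).sum(dim=1, keepdim=True) / math.sqrt(128)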

class VToonifyResBlock(nn.Module):
    def __init__(self, fin):
        super().__init__()

        self.conv = nn.Conv2d(fin, fin, 3, 1, 1)
        self.conv2 = nn.Conv2d(fin, fin, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        out = self.lrelu(self.conv(x))
        out = self.lrelu(self.conv2(out))
        out = (out + x) / math.sqrt(2)
        return out
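
# Dividing the residual sum by sqrt(2) keeps the output variance roughly equal
# to that of each branch, the same normalization StyleGAN2 uses in its own
# residual blocks.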

class Fusion(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()

        # create conv layers
        self.conv = nn.Conv2d(in_channels + skip_channels, out_channels, 3, 1, 1, bias=True)
        self.norm = AdaptiveInstanceNorm(in_channels + skip_channels, 128)
        self.conv2 = nn.Conv2d(in_channels + skip_channels, 1, 3, 1, 1, bias=True)
        # map the scalar style degree to a 128-dimensional AdaIN condition
        self.linear = nn.Sequential(
            nn.Linear(1, 64),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(64, 128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

    def forward(self, f_G, f_E, d_s=1):
        # embed the style degree d_s as the AdaIN condition
        label = self.linear(torch.zeros(f_G.size(0), 1).to(f_G.device) + d_s)
        # predict a single-channel mask m_E from the generator feature f_G and
        # its absolute difference with the encoder feature f_E
        out = torch.cat([f_G, torch.abs(f_G - f_E)], dim=1)
        m_E = F.relu(self.conv2(self.norm(out, label))).tanh()
        # fuse the masked encoder feature into the generator feature
        f_out = self.conv(torch.cat([f_G, f_E * m_E], dim=1))
        return f_out, m_E
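
# Fusion usage sketch (shapes here are illustrative assumptions): given a
# generator feature f_G and an encoder feature f_E of the same shape, the
# module returns the fused feature and the predicted mask:
#
#   fusion = Fusion(512, 512, 512)
#   f_G = torch.randn(1, 512, 32, 32)
#   f_E = torch.randn(1, 512, 32, 32)
#   f_out, m_E = fusion(f_G, f_E, d_s=0.5)   # f_out: (1, 512, 32, 32), m_E: (1, 1, 32, 32)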

class VToonify(nn.Module):
    def __init__(self,
                 in_size=256,
                 out_size=1024,
                 img_channels=3,
                 style_channels=512,
                 num_mlps=8,
                 channel_multiplier=2,
                 num_res_layers=6,
                 backbone='dualstylegan',
                 ):

        super().__init__()

        self.backbone = backbone
        if self.backbone == 'dualstylegan':
            # DualStyleGAN, with weights being fixed
            self.generator = DualStyleGAN(out_size, style_channels, num_mlps, channel_multiplier)
        else:
            # StyleGANv2, with weights being fixed
            self.generator = Generator(out_size, style_channels, num_mlps, channel_multiplier)

        self.in_size = in_size
        self.style_channels = style_channels
        channels = self.generator.channels

        # encoder
        num_styles = int(np.log2(out_size)) * 2 - 2
        encoder_res = [2 ** i for i in range(int(np.log2(in_size)), 4, -1)]
        self.encoder = nn.ModuleList()
        # the extra 19 input channels hold the face-parsing map that is
        # concatenated with the image
        self.encoder.append(
            nn.Sequential(
                nn.Conv2d(img_channels + 19, 32, 3, 1, 1, bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(32, channels[in_size], 3, 1, 1, bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)))

        for res in encoder_res:
            in_channels = channels[res]
            if res > 32:
                # downsampling block: halve the resolution, then refine
                out_channels = channels[res // 2]
                block = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True))
                self.encoder.append(block)
            else:
                # bottleneck at 32x32: a stack of residual blocks, followed by
                # a 1x1 convolution producing the low-resolution RGB skip
                layers = []
                for _ in range(num_res_layers):
                    layers.append(VToonifyResBlock(in_channels))
                self.encoder.append(nn.Sequential(*layers))
                block = nn.Conv2d(in_channels, img_channels, 1, 1, 0, bias=True)
                self.encoder.append(block)

        # trainable fusion module
        self.fusion_out = nn.ModuleList()
        self.fusion_skip = nn.ModuleList()
        for res in encoder_res[::-1]:
            num_channels = channels[res]
            if self.backbone == 'dualstylegan':
                self.fusion_out.append(
                    Fusion(num_channels, num_channels, num_channels))
            else:
                self.fusion_out.append(
                    nn.Conv2d(num_channels * 2, num_channels, 3, 1, 1, bias=True))

            self.fusion_skip.append(
                nn.Conv2d(num_channels + 3, 3, 3, 1, 1, bias=True))

        # modified ModRes blocks in DualStyleGAN, with weights being fixed
        if self.backbone == 'dualstylegan':
            self.res = nn.ModuleList()
            self.res.append(AdaResBlock(self.generator.channels[2 ** 2]))  # for conv1, not used in this model
            for i in range(3, 6):
                out_channel = self.generator.channels[2 ** i]
                self.res.append(AdaResBlock(out_channel, dilation=2 ** (5 - i)))
                self.res.append(AdaResBlock(out_channel, dilation=2 ** (5 - i)))

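    # self.encoder now holds: the stem plus the downsampling blocks (used as
    # self.encoder[:-2] in forward), the 32x32 residual stack
    # (self.encoder[-2]), and the 1x1 RGB head (self.encoder[-1]).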

    def forward(self, x, style, d_s=None, return_mask=False, return_feat=False):
        # map the style code to W+ space
        if style is not None and style.ndim < 3:
            if self.backbone == 'dualstylegan':
                resstyles = self.generator.style(style).unsqueeze(1).repeat(1, self.generator.n_latent, 1)
            adastyles = style.unsqueeze(1).repeat(1, self.generator.n_latent, 1)
        elif style is not None:
            nB, nL, nD = style.shape
            if self.backbone == 'dualstylegan':
                resstyles = self.generator.style(style.reshape(nB * nL, nD)).reshape(nB, nL, nD)
            adastyles = style
        if self.backbone == 'dualstylegan':
            # transform the fine-layer style codes with the fixed generator's res blocks
            adastyles = adastyles.clone()
            for i in range(7, self.generator.n_latent):
                adastyles[:, i] = self.generator.res[i](adastyles[:, i])

        # obtain multi-scale content features
        feat = x
        encoder_features = []
        # downsampling conv parts of E
        for block in self.encoder[:-2]:
            feat = block(feat)
            encoder_features.append(feat)
        encoder_features = encoder_features[::-1]
        # ResBlocks in E
        for ii, block in enumerate(self.encoder[-2]):
            feat = block(feat)
            # adjust the ResBlocks with the ModRes blocks
            if self.backbone == 'dualstylegan':
                feat = self.res[ii + 1](feat, resstyles[:, ii + 1], d_s)
        # the last-layer feature of E (input to the backbone)
        out = feat
        skip = self.encoder[-1](feat)
        if return_feat:
            return out, skip

        # decode from 32x32 up to the output resolution
        _index = 1
        m_Es = []
        for conv1, conv2, to_rgb in zip(
                self.stylegan().convs[6::2], self.stylegan().convs[7::2], self.stylegan().to_rgbs[3:]):

            # pass the mid-layer features of E to the corresponding resolution layers of G
            if 2 ** (5 + ((_index - 1) // 2)) <= self.in_size:
                fusion_index = (_index - 1) // 2
                f_E = encoder_features[fusion_index]

                if self.backbone == 'dualstylegan':
                    out, m_E = self.fusion_out[fusion_index](out, f_E, d_s)
                    skip = self.fusion_skip[fusion_index](torch.cat([skip, f_E * m_E], dim=1))
                    m_Es += [m_E]
                else:
                    out = self.fusion_out[fusion_index](torch.cat([out, f_E], dim=1))
                    skip = self.fusion_skip[fusion_index](torch.cat([skip, f_E], dim=1))

            # remove the noise input by scaling it to zero
            batch, _, height, width = out.shape
            noise = x.new_empty(batch, 1, height * 2, width * 2).normal_().detach() * 0.0

            out = conv1(out, adastyles[:, _index + 6], noise=noise)
            out = conv2(out, adastyles[:, _index + 7], noise=noise)
            skip = to_rgb(out, adastyles[:, _index + 8], skip)
            _index += 2

        image = skip
        if return_mask and self.backbone == 'dualstylegan':
            return image, m_Es
        return image

    def stylegan(self):
        # return the fixed StyleGAN2 backbone, whichever generator wraps it
        if self.backbone == 'dualstylegan':
            return self.generator.generator
        else:
            return self.generator

    def zplus2wplus(self, zplus):
        # map Z+ codes of shape (batch, n_latent, style_channels) to W+ space
        return self.stylegan().style(zplus.reshape(zplus.shape[0] * zplus.shape[1], zplus.shape[2])).reshape(zplus.shape)
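
# Usage sketch (shapes and values are assumptions for illustration). The
# encoder expects the image concatenated with a 19-channel face-parsing map,
# matching the img_channels + 19 input convolution above:
#
#   model = VToonify(backbone='dualstylegan')
#   x = torch.randn(1, 3 + 19, 256, 256)      # image + parsing channels
#   zplus = torch.randn(1, model.generator.n_latent, model.style_channels)
#   wplus = model.zplus2wplus(zplus)          # Z+ -> W+
#   y = model(x, wplus, d_s=0.5)              # stylized image at out_size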