Arnaudding001 committed on
Commit cf81973 · 1 Parent(s): c671d25

Update stylegan_model.py

Files changed (1)
  1. stylegan_model.py +674 -81
stylegan_model.py CHANGED
@@ -1,126 +1,719 @@
  import math
- import pickle


  import torch
- from torch import distributed as dist
- from torch.utils.data.sampler import Sampler


- def get_rank():
-     if not dist.is_available():
-         return 0

-     if not dist.is_initialized():
-         return 0

-     return dist.get_rank()


- def synchronize():
-     if not dist.is_available():
-         return

-     if not dist.is_initialized():
-         return

-     world_size = dist.get_world_size()

-     if world_size == 1:
-         return

-     dist.barrier()


- def get_world_size():
-     if not dist.is_available():
-         return 1

-     if not dist.is_initialized():
-         return 1

-     return dist.get_world_size()


- def reduce_sum(tensor):
-     if not dist.is_available():
-         return tensor

-     if not dist.is_initialized():
-         return tensor

-     tensor = tensor.clone()
-     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

-     return tensor


- def gather_grad(params):
-     world_size = get_world_size()

-     if world_size == 1:
-         return

-     for param in params:
-         if param.grad is not None:
-             dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
-             param.grad.data.div_(world_size)


- def all_gather(data):
-     world_size = get_world_size()

-     if world_size == 1:
-         return [data]

-     buffer = pickle.dumps(data)
-     storage = torch.ByteStorage.from_buffer(buffer)
-     tensor = torch.ByteTensor(storage).to('cuda')

-     local_size = torch.IntTensor([tensor.numel()]).to('cuda')
-     size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
-     dist.all_gather(size_list, local_size)
-     size_list = [int(size.item()) for size in size_list]
-     max_size = max(size_list)

-     tensor_list = []
-     for _ in size_list:
-         tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

-     if local_size != max_size:
-         padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
-         tensor = torch.cat((tensor, padding), 0)

-     dist.all_gather(tensor_list, tensor)

-     data_list = []

-     for size, tensor in zip(size_list, tensor_list):
-         buffer = tensor.cpu().numpy().tobytes()[:size]
-         data_list.append(pickle.loads(buffer))

-     return data_list


- def reduce_loss_dict(loss_dict):
-     world_size = get_world_size()

-     if world_size < 2:
-         return loss_dict

-     with torch.no_grad():
-         keys = []
-         losses = []

-         for k in sorted(loss_dict.keys()):
-             keys.append(k)
-             losses.append(loss_dict[k])

-         losses = torch.stack(losses, 0)
-         dist.reduce(losses, dst=0)

-         if dist.get_rank() == 0:
-             losses /= world_size

-         reduced_losses = {k: v for k, v in zip(keys, losses)}

-     return reduced_losses

  import math
+ import random
+ import functools
+ import operator

  import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.autograd import Function

+ from model.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix


+ class PixelNorm(nn.Module):
+     def __init__(self):
+         super().__init__()

+     def forward(self, input):
+         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


+ def make_kernel(k):
+     k = torch.tensor(k, dtype=torch.float32)

+     if k.ndim == 1:
+         k = k[None, :] * k[:, None]

+     k /= k.sum()

+     return k

+ class Upsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()

+         self.factor = factor
+         kernel = make_kernel(kernel) * (factor ** 2)
+         self.register_buffer("kernel", kernel)

+         p = kernel.shape[0] - factor

+         pad0 = (p + 1) // 2 + factor - 1
+         pad1 = p // 2

+         self.pad = (pad0, pad1)

+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

+         return out


+ class Downsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()

+         self.factor = factor
+         kernel = make_kernel(kernel)
+         self.register_buffer("kernel", kernel)

+         p = kernel.shape[0] - factor

+         pad0 = (p + 1) // 2
+         pad1 = p // 2

+         self.pad = (pad0, pad1)

+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

+         return out


+ class Blur(nn.Module):
+     def __init__(self, kernel, pad, upsample_factor=1):
+         super().__init__()

+         kernel = make_kernel(kernel)

+         if upsample_factor > 1:
+             kernel = kernel * (upsample_factor ** 2)

+         self.register_buffer("kernel", kernel)

+         self.pad = pad

+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, pad=self.pad)

+         return out

+ class EqualConv2d(nn.Module):
+     def __init__(
+         self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dilation=1  ## modified
+     ):
+         super().__init__()

+         self.weight = nn.Parameter(
+             torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+         )
+         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation  ## modified

+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_channel))

+         else:
+             self.bias = None

+     def forward(self, input):
+         out = conv2d_gradfix.conv2d(
+             input,
+             self.weight * self.scale,
+             bias=self.bias,
+             stride=self.stride,
+             padding=self.padding,
+             dilation=self.dilation,  ## modified
+         )

+         return out

+     def __repr__(self):
+         return (
+             f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+             f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding}, dilation={self.dilation})"  ## modified
+         )


+ class EqualLinear(nn.Module):
+     def __init__(
+         self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+     ):
+         super().__init__()

+         self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

+         else:
+             self.bias = None

+         self.activation = activation

+         self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+         self.lr_mul = lr_mul

+     def forward(self, input):
+         if self.activation:
+             out = F.linear(input, self.weight * self.scale)
+             out = fused_leaky_relu(out, self.bias * self.lr_mul)

+         else:
+             out = F.linear(
+                 input, self.weight * self.scale, bias=self.bias * self.lr_mul
+             )

+         return out

+     def __repr__(self):
+         return (
+             f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
+         )

+ class ModulatedConv2d(nn.Module):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         style_dim,
+         demodulate=True,
+         upsample=False,
+         downsample=False,
+         blur_kernel=[1, 3, 3, 1],
+         fused=True,
+     ):
+         super().__init__()

+         self.eps = 1e-8
+         self.kernel_size = kernel_size
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+         self.upsample = upsample
+         self.downsample = downsample

+         if upsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) - (kernel_size - 1)
+             pad0 = (p + 1) // 2 + factor - 1
+             pad1 = p // 2 + 1

+             self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2

+             self.blur = Blur(blur_kernel, pad=(pad0, pad1))

+         fan_in = in_channel * kernel_size ** 2
+         self.scale = 1 / math.sqrt(fan_in)
+         self.padding = kernel_size // 2

+         self.weight = nn.Parameter(
+             torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+         )

+         self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

+         self.demodulate = demodulate
+         self.fused = fused

+     def __repr__(self):
+         return (
+             f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
+             f"upsample={self.upsample}, downsample={self.downsample})"
+         )

+     def forward(self, input, style, externalweight=None):
+         batch, in_channel, height, width = input.shape

+         if not self.fused:
+             weight = self.scale * self.weight.squeeze(0)
+             style = self.modulation(style)

+             if self.demodulate:
+                 w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
+                 dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

+             input = input * style.reshape(batch, in_channel, 1, 1)

+             if self.upsample:
+                 weight = weight.transpose(0, 1)
+                 out = conv2d_gradfix.conv_transpose2d(
+                     input, weight, padding=0, stride=2
+                 )
+                 out = self.blur(out)

+             elif self.downsample:
+                 input = self.blur(input)
+                 out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

+             else:
+                 out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)

+             if self.demodulate:
+                 out = out * dcoefs.view(batch, -1, 1, 1)

+             return out

+         style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+         if externalweight is None:
+             weight = self.scale * self.weight * style
+         else:
+             weight = self.scale * (self.weight + externalweight) * style

+         if self.demodulate:
+             demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+             weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

+         weight = weight.view(
+             batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+         )

+         if self.upsample:
+             input = input.view(1, batch * in_channel, height, width)
+             weight = weight.view(
+                 batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+             )
+             weight = weight.transpose(1, 2).reshape(
+                 batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+             )
+             out = conv2d_gradfix.conv_transpose2d(
+                 input, weight, padding=0, stride=2, groups=batch
+             )
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+             out = self.blur(out)

+         elif self.downsample:
+             input = self.blur(input)
+             _, _, height, width = input.shape
+             input = input.view(1, batch * in_channel, height, width)
+             out = conv2d_gradfix.conv2d(
+                 input, weight, padding=0, stride=2, groups=batch
+             )
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)

+         else:
+             input = input.view(1, batch * in_channel, height, width)
+             out = conv2d_gradfix.conv2d(
+                 input, weight, padding=self.padding, groups=batch
+             )
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)

+         return out

+ class NoiseInjection(nn.Module):
+     def __init__(self):
+         super().__init__()

+         self.weight = nn.Parameter(torch.zeros(1))

+     def forward(self, image, noise=None):
+         if noise is None:
+             batch, _, height, width = image.shape
+             noise = image.new_empty(batch, 1, height, width).normal_()

+         return image + self.weight * noise


+ class ConstantInput(nn.Module):
+     def __init__(self, channel, size=4):
+         super().__init__()

+         self.input = nn.Parameter(torch.randn(1, channel, size, size))

+     def forward(self, input):
+         batch = input.shape[0]
+         out = self.input.repeat(batch, 1, 1, 1)

+         return out

+ class StyledConv(nn.Module):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         style_dim,
+         upsample=False,
+         blur_kernel=[1, 3, 3, 1],
+         demodulate=True,
+     ):
+         super().__init__()

+         self.conv = ModulatedConv2d(
+             in_channel,
+             out_channel,
+             kernel_size,
+             style_dim,
+             upsample=upsample,
+             blur_kernel=blur_kernel,
+             demodulate=demodulate,
+         )

+         self.noise = NoiseInjection()
+         # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+         # self.activate = ScaledLeakyReLU(0.2)
+         self.activate = FusedLeakyReLU(out_channel)

+     def forward(self, input, style, noise=None, externalweight=None):
+         out = self.conv(input, style, externalweight)
+         out = self.noise(out, noise=noise)
+         # out = out + self.bias
+         out = self.activate(out)

+         return out


+ class ToRGB(nn.Module):
+     def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()

+         if upsample:
+             self.upsample = Upsample(blur_kernel)

+         self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+         self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

+     def forward(self, input, style, skip=None, externalweight=None):
+         out = self.conv(input, style, externalweight)
+         out = out + self.bias

+         if skip is not None:
+             skip = self.upsample(skip)

+             out = out + skip

+         return out

+ class Generator(nn.Module):
+     def __init__(
+         self,
+         size,
+         style_dim,
+         n_mlp,
+         channel_multiplier=2,
+         blur_kernel=[1, 3, 3, 1],
+         lr_mlp=0.01,
+     ):
+         super().__init__()

+         self.size = size

+         self.style_dim = style_dim

+         layers = [PixelNorm()]

+         for i in range(n_mlp):
+             layers.append(
+                 EqualLinear(
+                     style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
+                 )
+             )

+         self.style = nn.Sequential(*layers)

+         self.channels = {
+             4: 512,
+             8: 512,
+             16: 512,
+             32: 512,
+             64: 256 * channel_multiplier,
+             128: 128 * channel_multiplier,
+             256: 64 * channel_multiplier,
+             512: 32 * channel_multiplier,
+             1024: 16 * channel_multiplier,
+         }

+         self.input = ConstantInput(self.channels[4])
+         self.conv1 = StyledConv(
+             self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+         )
+         self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

+         self.log_size = int(math.log(size, 2))
+         self.num_layers = (self.log_size - 2) * 2 + 1

+         self.convs = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.to_rgbs = nn.ModuleList()
+         self.noises = nn.Module()

+         in_channel = self.channels[4]

+         for layer_idx in range(self.num_layers):
+             res = (layer_idx + 5) // 2
+             shape = [1, 1, 2 ** res, 2 ** res]
+             self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

+         for i in range(3, self.log_size + 1):
+             out_channel = self.channels[2 ** i]

+             self.convs.append(
+                 StyledConv(
+                     in_channel,
+                     out_channel,
+                     3,
+                     style_dim,
+                     upsample=True,
+                     blur_kernel=blur_kernel,
+                 )
+             )

+             self.convs.append(
+                 StyledConv(
+                     out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+                 )
+             )

+             self.to_rgbs.append(ToRGB(out_channel, style_dim))

+             in_channel = out_channel

+         self.n_latent = self.log_size * 2 - 2

+     def make_noise(self):
+         device = self.input.input.device

+         noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

+         for i in range(3, self.log_size + 1):
+             for _ in range(2):
+                 noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

+         return noises

+     def mean_latent(self, n_latent):
+         latent_in = torch.randn(
+             n_latent, self.style_dim, device=self.input.input.device
+         )
+         latent = self.style(latent_in).mean(0, keepdim=True)

+         return latent

+     def get_latent(self, input):
+         return self.style(input)

+     def forward(
+         self,
+         styles,
+         return_latents=False,
+         inject_index=None,
+         truncation=1,
+         truncation_latent=None,
+         input_is_latent=False,
+         noise=None,
+         randomize_noise=True,
+         z_plus_latent=False,
+         return_feature_ind=999,
+     ):
+         if not input_is_latent:
+             if not z_plus_latent:
+                 styles = [self.style(s) for s in styles]
+             else:
+                 styles_ = []
+                 for s in styles:
+                     style_ = []
+                     for i in range(s.shape[1]):
+                         style_.append(self.style(s[:, i]).unsqueeze(1))
+                     styles_.append(torch.cat(style_, dim=1))
+                 styles = styles_

+         if noise is None:
+             if randomize_noise:
+                 noise = [None] * self.num_layers
+             else:
+                 noise = [
+                     getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
+                 ]

+         if truncation < 1:
+             style_t = []

+             for style in styles:
+                 style_t.append(
+                     truncation_latent + truncation * (style - truncation_latent)
+                 )

+             styles = style_t

+         if len(styles) < 2:
+             inject_index = self.n_latent

+             if styles[0].ndim < 3:
+                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

+             else:
+                 latent = styles[0]

+         else:
+             if inject_index is None:
+                 inject_index = random.randint(1, self.n_latent - 1)

+             if styles[0].ndim < 3:
+                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+                 latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

+                 latent = torch.cat([latent, latent2], 1)
+             else:
+                 latent = torch.cat([styles[0][:, 0:inject_index], styles[1][:, inject_index:]], 1)

+         out = self.input(latent)
+         out = self.conv1(out, latent[:, 0], noise=noise[0])

+         skip = self.to_rgb1(out, latent[:, 1])

+         i = 1
+         for conv1, conv2, noise1, noise2, to_rgb in zip(
+             self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+         ):
+             out = conv1(out, latent[:, i], noise=noise1)
+             out = conv2(out, latent[:, i + 1], noise=noise2)
+             skip = to_rgb(out, latent[:, i + 2], skip)

+             i += 2
+             if i > return_feature_ind:
+                 return out, skip

+         image = skip

+         if return_latents:
+             return image, latent

+         else:
+             return image, None

+ class ConvLayer(nn.Sequential):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         downsample=False,
+         blur_kernel=[1, 3, 3, 1],
+         bias=True,
+         activate=True,
+         dilation=1,  ## modified
+     ):
+         layers = []

+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2

+             layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

+             stride = 2
+             self.padding = 0

+         else:
+             stride = 1
+             self.padding = kernel_size // 2 + dilation - 1  ## modified

+         layers.append(
+             EqualConv2d(
+                 in_channel,
+                 out_channel,
+                 kernel_size,
+                 padding=self.padding,
+                 stride=stride,
+                 bias=bias and not activate,
+                 dilation=dilation,  ## modified
+             )
+         )

+         if activate:
+             layers.append(FusedLeakyReLU(out_channel, bias=bias))

+         super().__init__(*layers)


+ class ResBlock(nn.Module):
+     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()

+         self.conv1 = ConvLayer(in_channel, in_channel, 3)
+         self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

+         self.skip = ConvLayer(
+             in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+         )

+     def forward(self, input):
+         out = self.conv1(input)
+         out = self.conv2(out)

+         skip = self.skip(input)
+         out = (out + skip) / math.sqrt(2)

+         return out

+ class Discriminator(nn.Module):
+     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()

+         channels = {
+             4: 512,
+             8: 512,
+             16: 512,
+             32: 512,
+             64: 256 * channel_multiplier,
+             128: 128 * channel_multiplier,
+             256: 64 * channel_multiplier,
+             512: 32 * channel_multiplier,
+             1024: 16 * channel_multiplier,
+         }

+         convs = [ConvLayer(3, channels[size], 1)]

+         log_size = int(math.log(size, 2))

+         in_channel = channels[size]

+         for i in range(log_size, 2, -1):
+             out_channel = channels[2 ** (i - 1)]

+             convs.append(ResBlock(in_channel, out_channel, blur_kernel))

+             in_channel = out_channel

+         self.convs = nn.Sequential(*convs)

+         self.stddev_group = 4
+         self.stddev_feat = 1

+         self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+         self.final_linear = nn.Sequential(
+             EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
+             EqualLinear(channels[4], 1),
+         )

+     def forward(self, input):
+         out = self.convs(input)

+         batch, channel, height, width = out.shape
+         group = min(batch, self.stddev_group)
+         stddev = out.view(
+             group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+         )
+         stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+         stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+         stddev = stddev.repeat(group, 1, height, width)
+         out = torch.cat([out, stddev], 1)

+         out = self.final_conv(out)

+         out = out.view(batch, -1)
+         out = self.final_linear(out)

+         return out
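
For context, here is a minimal usage sketch of the updated module, not part of the commit itself. It assumes the custom ops under model.stylegan.op are built and importable, that stylegan_model.py is on the Python path, and that the resolution, batch size, and truncation values are only illustrative.

    # Minimal sketch: sample images from the Generator and score them with the Discriminator.
    # Assumptions: model.stylegan.op extensions are available; values below are illustrative.
    import torch

    from stylegan_model import Generator, Discriminator

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 256x256 generator with a 512-d style space and an 8-layer mapping network.
    g = Generator(size=256, style_dim=512, n_mlp=8).to(device)
    d = Discriminator(size=256).to(device)

    z = torch.randn(4, 512, device=device)            # latent codes z
    image, w_plus = g([z], return_latents=True)       # image: (4, 3, 256, 256)

    # Truncation toward the mean latent, as in the usual StyleGAN2 sampling recipe.
    mean_w = g.mean_latent(4096)
    image_trunc, _ = g([z], truncation=0.7, truncation_latent=mean_w)

    score = d(image)                                   # (4, 1) realism scores

The forward pass accepts a list of style codes, so passing two entries enables style mixing, and z_plus_latent=True lets per-layer codes of shape (batch, n_latent, style_dim) be mapped layer by layer.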