Arnaudding001 commited on
Commit
5914e7c
·
1 Parent(s): fbe25e3

Create stylegan_non_leaking.py

Browse files
Files changed (1) hide show
  1. stylegan_non_leaking.py +469 -0
stylegan_non_leaking.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import autograd
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+
8
+ from model.stylegan.distributed import reduce_sum
9
+ from model.stylegan.op import upfirdn2d
10
+
11
+
12
+ class AdaptiveAugment:
13
+ def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
14
+ self.ada_aug_target = ada_aug_target
15
+ self.ada_aug_len = ada_aug_len
16
+ self.update_every = update_every
17
+
18
+ self.ada_update = 0
19
+ self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
20
+ self.r_t_stat = 0
21
+ self.ada_aug_p = 0
22
+
23
+ @torch.no_grad()
24
+ def tune(self, real_pred):
25
+ self.ada_aug_buf += torch.tensor(
26
+ (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
27
+ device=real_pred.device,
28
+ )
29
+ self.ada_update += 1
30
+
31
+ if self.ada_update % self.update_every == 0:
32
+ self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
33
+ pred_signs, n_pred = self.ada_aug_buf.tolist()
34
+
35
+ self.r_t_stat = pred_signs / n_pred
36
+
37
+ if self.r_t_stat > self.ada_aug_target:
38
+ sign = 1
39
+
40
+ else:
41
+ sign = -1
42
+
43
+ self.ada_aug_p += sign * n_pred / self.ada_aug_len
44
+ self.ada_aug_p = min(1, max(0, self.ada_aug_p))
45
+ self.ada_aug_buf.mul_(0)
46
+ self.ada_update = 0
47
+
48
+ return self.ada_aug_p
49
+
50
+
51
+ SYM6 = (
52
+ 0.015404109327027373,
53
+ 0.0034907120842174702,
54
+ -0.11799011114819057,
55
+ -0.048311742585633,
56
+ 0.4910559419267466,
57
+ 0.787641141030194,
58
+ 0.3379294217276218,
59
+ -0.07263752278646252,
60
+ -0.021060292512300564,
61
+ 0.04472490177066578,
62
+ 0.0017677118642428036,
63
+ -0.007800708325034148,
64
+ )
65
+
66
+
67
+ def translate_mat(t_x, t_y, device="cpu"):
68
+ batch = t_x.shape[0]
69
+
70
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
71
+ translate = torch.stack((t_x, t_y), 1)
72
+ mat[:, :2, 2] = translate
73
+
74
+ return mat
75
+
76
+
77
+ def rotate_mat(theta, device="cpu"):
78
+ batch = theta.shape[0]
79
+
80
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
81
+ sin_t = torch.sin(theta)
82
+ cos_t = torch.cos(theta)
83
+ rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
84
+ mat[:, :2, :2] = rot
85
+
86
+ return mat
87
+
88
+
89
+ def scale_mat(s_x, s_y, device="cpu"):
90
+ batch = s_x.shape[0]
91
+
92
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
93
+ mat[:, 0, 0] = s_x
94
+ mat[:, 1, 1] = s_y
95
+
96
+ return mat
97
+
98
+
99
+ def translate3d_mat(t_x, t_y, t_z):
100
+ batch = t_x.shape[0]
101
+
102
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
103
+ translate = torch.stack((t_x, t_y, t_z), 1)
104
+ mat[:, :3, 3] = translate
105
+
106
+ return mat
107
+
108
+
109
+ def rotate3d_mat(axis, theta):
110
+ batch = theta.shape[0]
111
+
112
+ u_x, u_y, u_z = axis
113
+
114
+ eye = torch.eye(3).unsqueeze(0)
115
+ cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
116
+ outer = torch.tensor(axis)
117
+ outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
118
+
119
+ sin_t = torch.sin(theta).view(-1, 1, 1)
120
+ cos_t = torch.cos(theta).view(-1, 1, 1)
121
+
122
+ rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
123
+
124
+ eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
125
+ eye_4[:, :3, :3] = rot
126
+
127
+ return eye_4
128
+
129
+
130
+ def scale3d_mat(s_x, s_y, s_z):
131
+ batch = s_x.shape[0]
132
+
133
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
134
+ mat[:, 0, 0] = s_x
135
+ mat[:, 1, 1] = s_y
136
+ mat[:, 2, 2] = s_z
137
+
138
+ return mat
139
+
140
+
141
+ def luma_flip_mat(axis, i):
142
+ batch = i.shape[0]
143
+
144
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
145
+ axis = torch.tensor(axis + (0,))
146
+ flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
147
+
148
+ return eye - flip
149
+
150
+
151
+ def saturation_mat(axis, i):
152
+ batch = i.shape[0]
153
+
154
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
155
+ axis = torch.tensor(axis + (0,))
156
+ axis = torch.ger(axis, axis)
157
+ saturate = axis + (eye - axis) * i.view(-1, 1, 1)
158
+
159
+ return saturate
160
+
161
+
162
+ def lognormal_sample(size, mean=0, std=1, device="cpu"):
163
+ return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
164
+
165
+
166
+ def category_sample(size, categories, device="cpu"):
167
+ category = torch.tensor(categories, device=device)
168
+ sample = torch.randint(high=len(categories), size=(size,), device=device)
169
+
170
+ return category[sample]
171
+
172
+
173
+ def uniform_sample(size, low, high, device="cpu"):
174
+ return torch.empty(size, device=device).uniform_(low, high)
175
+
176
+
177
+ def normal_sample(size, mean=0, std=1, device="cpu"):
178
+ return torch.empty(size, device=device).normal_(mean, std)
179
+
180
+
181
+ def bernoulli_sample(size, p, device="cpu"):
182
+ return torch.empty(size, device=device).bernoulli_(p)
183
+
184
+
185
+ def random_mat_apply(p, transform, prev, eye, device="cpu"):
186
+ size = transform.shape[0]
187
+ select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
188
+ select_transform = select * transform + (1 - select) * eye
189
+
190
+ return select_transform @ prev
191
+
192
+
193
+ def sample_affine(p, size, height, width, device="cpu"):
194
+ G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
195
+ eye = G
196
+
197
+ # flip
198
+ param = category_sample(size, (0, 1))
199
+ Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
200
+ G = random_mat_apply(p, Gc, G, eye, device=device)
201
+ # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
202
+
203
+ # 90 rotate
204
+ #param = category_sample(size, (0, 3))
205
+ #Gc = rotate_mat(-math.pi / 2 * param, device=device)
206
+ #G = random_mat_apply(p, Gc, G, eye, device=device)
207
+ # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208
+
209
+ # integer translate
210
+ param = uniform_sample(size, -0.125, 0.125)
211
+ param_height = torch.round(param * height) / height
212
+ param_width = torch.round(param * width) / width
213
+ Gc = translate_mat(param_width, param_height, device=device)
214
+ G = random_mat_apply(p, Gc, G, eye, device=device)
215
+ # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
216
+
217
+ # isotropic scale
218
+ param = lognormal_sample(size, std=0.2 * math.log(2))
219
+ Gc = scale_mat(param, param, device=device)
220
+ G = random_mat_apply(p, Gc, G, eye, device=device)
221
+ # print('isotropic scale', G, scale_mat(param, param), sep='\n')
222
+
223
+ p_rot = 1 - math.sqrt(1 - p)
224
+
225
+ # pre-rotate
226
+ param = uniform_sample(size, -math.pi, math.pi)
227
+ Gc = rotate_mat(-param, device=device)
228
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
229
+ # print('pre-rotate', G, rotate_mat(-param), sep='\n')
230
+
231
+ # anisotropic scale
232
+ param = lognormal_sample(size, std=0.2 * math.log(2))
233
+ Gc = scale_mat(param, 1 / param, device=device)
234
+ G = random_mat_apply(p, Gc, G, eye, device=device)
235
+ # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
236
+
237
+ # post-rotate
238
+ param = uniform_sample(size, -math.pi, math.pi)
239
+ Gc = rotate_mat(-param, device=device)
240
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
241
+ # print('post-rotate', G, rotate_mat(-param), sep='\n')
242
+
243
+ # fractional translate
244
+ param = normal_sample(size, std=0.125)
245
+ Gc = translate_mat(param, param, device=device)
246
+ G = random_mat_apply(p, Gc, G, eye, device=device)
247
+ # print('fractional translate', G, translate_mat(param, param), sep='\n')
248
+
249
+ return G
250
+
251
+
252
+ def sample_color(p, size):
253
+ C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
254
+ eye = C
255
+ axis_val = 1 / math.sqrt(3)
256
+ axis = (axis_val, axis_val, axis_val)
257
+
258
+ # brightness
259
+ param = normal_sample(size, std=0.2)
260
+ Cc = translate3d_mat(param, param, param)
261
+ C = random_mat_apply(p, Cc, C, eye)
262
+
263
+ # contrast
264
+ param = lognormal_sample(size, std=0.5 * math.log(2))
265
+ Cc = scale3d_mat(param, param, param)
266
+ C = random_mat_apply(p, Cc, C, eye)
267
+
268
+ # luma flip
269
+ param = category_sample(size, (0, 1))
270
+ Cc = luma_flip_mat(axis, param)
271
+ C = random_mat_apply(p, Cc, C, eye)
272
+
273
+ # hue rotation
274
+ param = uniform_sample(size, -math.pi, math.pi)
275
+ Cc = rotate3d_mat(axis, param)
276
+ C = random_mat_apply(p, Cc, C, eye)
277
+
278
+ # saturation
279
+ param = lognormal_sample(size, std=1 * math.log(2))
280
+ Cc = saturation_mat(axis, param)
281
+ C = random_mat_apply(p, Cc, C, eye)
282
+
283
+ return C
284
+
285
+
286
+ def make_grid(shape, x0, x1, y0, y1, device):
287
+ n, c, h, w = shape
288
+ grid = torch.empty(n, h, w, 3, device=device)
289
+ grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
290
+ grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
291
+ grid[:, :, :, 2] = 1
292
+
293
+ return grid
294
+
295
+
296
+ def affine_grid(grid, mat):
297
+ n, h, w, _ = grid.shape
298
+ return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
299
+
300
+
301
+ def get_padding(G, height, width, kernel_size):
302
+ device = G.device
303
+
304
+ cx = (width - 1) / 2
305
+ cy = (height - 1) / 2
306
+ cp = torch.tensor(
307
+ [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
308
+ )
309
+ cp = G @ cp.T
310
+
311
+ pad_k = kernel_size // 4
312
+
313
+ pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
314
+ pad = torch.cat((-pad, pad)).max(1).values
315
+ pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
316
+ pad = pad.max(torch.tensor([0, 0] * 2, device=device))
317
+ pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
318
+
319
+ pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
320
+
321
+ return pad_x1, pad_x2, pad_y1, pad_y2
322
+
323
+
324
+ def try_sample_affine_and_pad(img, p, kernel_size, G=None):
325
+ batch, _, height, width = img.shape
326
+
327
+ G_try = G
328
+
329
+ if G is None:
330
+ G_try = torch.inverse(sample_affine(p, batch, height, width))
331
+
332
+ pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
333
+
334
+ img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
335
+
336
+ return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
337
+
338
+
339
+ class GridSampleForward(autograd.Function):
340
+ @staticmethod
341
+ def forward(ctx, input, grid):
342
+ out = F.grid_sample(
343
+ input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
344
+ )
345
+ ctx.save_for_backward(input, grid)
346
+
347
+ return out
348
+
349
+ @staticmethod
350
+ def backward(ctx, grad_output):
351
+ input, grid = ctx.saved_tensors
352
+ grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
353
+
354
+ return grad_input, grad_grid
355
+
356
+
357
+ class GridSampleBackward(autograd.Function):
358
+ @staticmethod
359
+ def forward(ctx, grad_output, input, grid):
360
+ op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
361
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
362
+ ctx.save_for_backward(grid)
363
+
364
+ return grad_input, grad_grid
365
+
366
+ @staticmethod
367
+ def backward(ctx, grad_grad_input, grad_grad_grid):
368
+ grid, = ctx.saved_tensors
369
+ grad_grad_output = None
370
+
371
+ if ctx.needs_input_grad[0]:
372
+ grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
373
+
374
+ return grad_grad_output, None, None
375
+
376
+
377
+ grid_sample = GridSampleForward.apply
378
+
379
+
380
+ def scale_mat_single(s_x, s_y):
381
+ return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
382
+
383
+
384
+ def translate_mat_single(t_x, t_y):
385
+ return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
386
+
387
+
388
+ def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
389
+ kernel = antialiasing_kernel
390
+ len_k = len(kernel)
391
+
392
+ kernel = torch.as_tensor(kernel).to(img)
393
+ # kernel = torch.ger(kernel, kernel).to(img)
394
+ kernel_flip = torch.flip(kernel, (0,))
395
+
396
+ img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
397
+ img, p, len_k, G
398
+ )
399
+
400
+ G_inv = (
401
+ translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
402
+ @ G
403
+ )
404
+ up_pad = (
405
+ (len_k + 2 - 1) // 2,
406
+ (len_k - 2) // 2,
407
+ (len_k + 2 - 1) // 2,
408
+ (len_k - 2) // 2,
409
+ )
410
+ img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
411
+ img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
412
+ G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
413
+ G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
414
+ batch_size, channel, height, width = img.shape
415
+ pad_k = len_k // 4
416
+ shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
417
+ G_inv = (
418
+ scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
419
+ @ G_inv
420
+ @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
421
+ )
422
+ grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
423
+ img_affine = grid_sample(img_2x, grid)
424
+ d_p = -pad_k * 2
425
+ down_pad = (
426
+ d_p + (len_k - 2 + 1) // 2,
427
+ d_p + (len_k - 2) // 2,
428
+ d_p + (len_k - 2 + 1) // 2,
429
+ d_p + (len_k - 2) // 2,
430
+ )
431
+ img_down = upfirdn2d(
432
+ img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
433
+ )
434
+ img_down = upfirdn2d(
435
+ img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
436
+ )
437
+
438
+ return img_down, G
439
+
440
+
441
+ def apply_color(img, mat):
442
+ batch = img.shape[0]
443
+ img = img.permute(0, 2, 3, 1)
444
+ mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
445
+ mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
446
+ img = img @ mat_mul + mat_add
447
+ img = img.permute(0, 3, 1, 2)
448
+
449
+ return img
450
+
451
+
452
+ def random_apply_color(img, p, C=None):
453
+ if C is None:
454
+ C = sample_color(p, img.shape[0])
455
+
456
+ img = apply_color(img, C.to(img))
457
+
458
+ return img, C
459
+
460
+
461
+ def augment(img, p, transform_matrix=(None, None)):
462
+ img, G = random_apply_affine(img, p, transform_matrix[0])
463
+ if img.shape[1] == 3:
464
+ img, C = random_apply_color(img, p, transform_matrix[1])
465
+ else:
466
+ tmp, C = random_apply_color(img[:,0:3], p, transform_matrix[1])
467
+ img = torch.cat((tmp, img[:,3:]), dim=1)
468
+
469
+ return img, (G, C)