DmitrMakeev commited on
Commit
add457a
1 Parent(s): 208a7f5

Upload 8 files

Browse files
models/layers/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \
6
+ HyperConv2dBlock, MultiOutConv2dBlock, \
7
+ PartialConv2dBlock, PartialConv3dBlock
8
+ from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \
9
+ HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \
10
+ PartialRes2dBlock, PartialRes3dBlock
11
+ # from .non_local import NonLocal2dBlock
12
+
13
+ __all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock',
14
+ 'HyperConv2dBlock', 'MultiOutConv2dBlock',
15
+ 'PartialConv2dBlock', 'PartialConv3dBlock',
16
+ 'Res1dBlock', 'Res2dBlock', 'Res3dBlock',
17
+ 'UpRes2dBlock', 'DownRes2dBlock',
18
+ 'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock',
19
+ 'PartialRes2dBlock', 'PartialRes3dBlock',]
models/layers/activation_norm.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import torch
4
+
5
+
6
+ try:
7
+ # from torch.nn import BatchNorm2d as SyncBatchNorm
8
+ from torch.nn import SyncBatchNorm
9
+ except ImportError:
10
+ from torch.nn import BatchNorm2d as SyncBatchNorm
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
14
+ from .misc import PartialSequential
15
+ import sync_batchnorm
16
+
17
+
18
+ class AdaptiveNorm(nn.Module):
19
+ r"""Adaptive normalization layer. The layer first normalizes the input, then
20
+ performs an affine transformation using parameters computed from the
21
+ conditional inputs.
22
+ Args:
23
+ num_features (int): Number of channels in the input tensor.
24
+ cond_dims (int): Number of channels in the conditional inputs.
25
+ weight_norm_type (str): Type of weight normalization.
26
+ ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
27
+ projection (bool): If ``True``, project the conditional input to gamma
28
+ and beta using a fully connected layer, otherwise directly use
29
+ the conditional input as gamma and beta.
30
+ separate_projection (bool): If ``True``, we will use two different
31
+ layers for gamma and beta. Otherwise, we will use one layer. It
32
+ matters only if you apply any weight norms to this layer.
33
+ input_dim (int): Number of dimensions of the input tensor.
34
+ activation_norm_type (str):
35
+ Type of activation normalization.
36
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
37
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
38
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
39
+ activation_norm_params (obj, optional, default=None):
40
+ Parameters of activation normalization.
41
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
42
+ keyword arguments when initializing activation normalization.
43
+ """
44
+
45
+ def __init__(self, num_features, cond_dims, weight_norm_type='',
46
+ projection=True,
47
+ separate_projection=False,
48
+ input_dim=2,
49
+ activation_norm_type='instance',
50
+ activation_norm_params=None):
51
+ super().__init__()
52
+ self.projection = projection
53
+ self.separate_projection = separate_projection
54
+ if activation_norm_params is None:
55
+ activation_norm_params = SimpleNamespace(affine=False)
56
+ self.norm = get_activation_norm_layer(num_features,
57
+ activation_norm_type,
58
+ input_dim,
59
+ **vars(activation_norm_params))
60
+ if self.projection:
61
+ if self.separate_projection:
62
+ self.fc_gamma = \
63
+ LinearBlock(cond_dims, num_features,
64
+ weight_norm_type=weight_norm_type)
65
+ self.fc_beta = \
66
+ LinearBlock(cond_dims, num_features,
67
+ weight_norm_type=weight_norm_type)
68
+ else:
69
+ self.fc = LinearBlock(cond_dims, num_features * 2,
70
+ weight_norm_type=weight_norm_type)
71
+
72
+ self.conditional = True
73
+
74
+ def forward(self, x, y, **kwargs):
75
+ r"""Adaptive Normalization forward.
76
+ Args:
77
+ x (N x C1 x * tensor): Input tensor.
78
+ y (N x C2 tensor): Conditional information.
79
+ Returns:
80
+ out (N x C1 x * tensor): Output tensor.
81
+ """
82
+ if self.projection:
83
+ if self.separate_projection:
84
+ gamma = self.fc_gamma(y)
85
+ beta = self.fc_beta(y)
86
+ for _ in range(x.dim() - gamma.dim()):
87
+ gamma = gamma.unsqueeze(-1)
88
+ beta = beta.unsqueeze(-1)
89
+ else:
90
+ y = self.fc(y)
91
+ for _ in range(x.dim() - y.dim()):
92
+ y = y.unsqueeze(-1)
93
+ gamma, beta = y.chunk(2, 1)
94
+ else:
95
+ for _ in range(x.dim() - y.dim()):
96
+ y = y.unsqueeze(-1)
97
+ gamma, beta = y.chunk(2, 1)
98
+ x = self.norm(x) if self.norm is not None else x
99
+ out = x * (1 + gamma) + beta
100
+ return out
101
+
102
+
103
+ class SpatiallyAdaptiveNorm(nn.Module):
104
+ r"""Spatially Adaptive Normalization (SPADE) initialization.
105
+ Args:
106
+ num_features (int) : Number of channels in the input tensor.
107
+ cond_dims (int or list of int) : List of numbers of channels
108
+ in the input.
109
+ num_filters (int): Number of filters in SPADE.
110
+ kernel_size (int): Kernel size of the convolutional filters in
111
+ the SPADE layer.
112
+ weight_norm_type (str): Type of weight normalization.
113
+ ``'none'``, ``'spectral'``, or ``'weight'``.
114
+ separate_projection (bool): If ``True``, we will use two different
115
+ layers for gamma and beta. Otherwise, we will use one layer. It
116
+ matters only if you apply any weight norms to this layer.
117
+ activation_norm_type (str):
118
+ Type of activation normalization.
119
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
120
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
121
+ activation_norm_params (obj, optional, default=None):
122
+ Parameters of activation normalization.
123
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
124
+ keyword arguments when initializing activation normalization.
125
+ """
126
+
127
+ def __init__(self,
128
+ num_features,
129
+ cond_dims,
130
+ num_filters=128,
131
+ kernel_size=3,
132
+ weight_norm_type='',
133
+ separate_projection=False,
134
+ activation_norm_type='sync_batch',
135
+ activation_norm_params=None,
136
+ partial=False):
137
+ super().__init__()
138
+ if activation_norm_params is None:
139
+ activation_norm_params = SimpleNamespace(affine=False)
140
+ padding = kernel_size // 2
141
+ self.separate_projection = separate_projection
142
+ self.mlps = nn.ModuleList()
143
+ self.gammas = nn.ModuleList()
144
+ self.betas = nn.ModuleList()
145
+
146
+ # Make cond_dims a list.
147
+ if type(cond_dims) != list:
148
+ cond_dims = [cond_dims]
149
+
150
+ # Make num_filters a list.
151
+ if not isinstance(num_filters, list):
152
+ num_filters = [num_filters] * len(cond_dims)
153
+ else:
154
+ assert len(num_filters) >= len(cond_dims)
155
+
156
+ # Make partial a list.
157
+ if not isinstance(partial, list):
158
+ partial = [partial] * len(cond_dims)
159
+ else:
160
+ assert len(partial) >= len(cond_dims)
161
+
162
+ for i, cond_dim in enumerate(cond_dims):
163
+ mlp = []
164
+ conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
165
+ sequential = PartialSequential if partial[i] else nn.Sequential
166
+
167
+ if num_filters[i] > 0:
168
+ mlp += [conv_block(cond_dim,
169
+ num_filters[i],
170
+ kernel_size,
171
+ padding=padding,
172
+ weight_norm_type=weight_norm_type,
173
+ nonlinearity='relu')]
174
+ mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
175
+
176
+ if self.separate_projection:
177
+ if partial[i]:
178
+ raise NotImplementedError(
179
+ 'Separate projection not yet implemented for ' +
180
+ 'partial conv')
181
+ self.mlps.append(nn.Sequential(*mlp))
182
+ self.gammas.append(
183
+ conv_block(mlp_ch, num_features,
184
+ kernel_size,
185
+ padding=padding,
186
+ weight_norm_type=weight_norm_type))
187
+ self.betas.append(
188
+ conv_block(mlp_ch, num_features,
189
+ kernel_size,
190
+ padding=padding,
191
+ weight_norm_type=weight_norm_type))
192
+ else:
193
+ mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
194
+ padding=padding,
195
+ weight_norm_type=weight_norm_type)]
196
+ self.mlps.append(sequential(*mlp))
197
+
198
+ self.norm = get_activation_norm_layer(num_features,
199
+ activation_norm_type,
200
+ 2,
201
+ **vars(activation_norm_params))
202
+ self.conditional = True
203
+
204
+ def forward(self, x, *cond_inputs, **kwargs):
205
+ r"""Spatially Adaptive Normalization (SPADE) forward.
206
+ Args:
207
+ x (N x C1 x H x W tensor) : Input tensor.
208
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
209
+ Returns:
210
+ output (4D tensor) : Output tensor.
211
+ """
212
+ output = self.norm(x) if self.norm is not None else x
213
+ for i in range(len(cond_inputs)):
214
+ if cond_inputs[i] is None:
215
+ continue
216
+ label_map = F.interpolate(cond_inputs[i], size=x.size()[2:],
217
+ mode='nearest')
218
+ if self.separate_projection:
219
+ hidden = self.mlps[i](label_map)
220
+ gamma = self.gammas[i](hidden)
221
+ beta = self.betas[i](hidden)
222
+ else:
223
+ affine_params = self.mlps[i](label_map)
224
+ gamma, beta = affine_params.chunk(2, dim=1)
225
+ output = output * (1 + gamma) + beta
226
+ return output
227
+
228
+
229
+ class HyperSpatiallyAdaptiveNorm(nn.Module):
230
+ r"""Spatially Adaptive Normalization (SPADE) initialization.
231
+ Args:
232
+ num_features (int) : Number of channels in the input tensor.
233
+ cond_dims (int or list of int) : List of numbers of channels
234
+ in the conditional input.
235
+ num_filters (int): Number of filters in SPADE.
236
+ kernel_size (int): Kernel size of the convolutional filters in
237
+ the SPADE layer.
238
+ weight_norm_type (str): Type of weight normalization.
239
+ ``'none'``, ``'spectral'``, or ``'weight'``.
240
+ activation_norm_type (str):
241
+ Type of activation normalization.
242
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
243
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
244
+ is_hyper (bool): Whether to use hyper SPADE.
245
+ """
246
+
247
+ def __init__(self, num_features, cond_dims,
248
+ num_filters=0, kernel_size=3,
249
+ weight_norm_type='',
250
+ activation_norm_type='sync_batch', is_hyper=True):
251
+ super().__init__()
252
+ padding = kernel_size // 2
253
+ self.mlps = nn.ModuleList()
254
+ if type(cond_dims) != list:
255
+ cond_dims = [cond_dims]
256
+
257
+ for i, cond_dim in enumerate(cond_dims):
258
+ mlp = []
259
+ if not is_hyper or (i != 0):
260
+ if num_filters > 0:
261
+ mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
262
+ padding=padding,
263
+ weight_norm_type=weight_norm_type,
264
+ nonlinearity='relu')]
265
+ mlp_ch = cond_dim if num_filters == 0 else num_filters
266
+ mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
267
+ padding=padding,
268
+ weight_norm_type=weight_norm_type)]
269
+ mlp = nn.Sequential(*mlp)
270
+ else:
271
+ if num_filters > 0:
272
+ raise ValueError('Multi hyper layer not supported yet.')
273
+ mlp = HyperConv2d(padding=padding)
274
+ self.mlps.append(mlp)
275
+
276
+ self.norm = get_activation_norm_layer(num_features,
277
+ activation_norm_type,
278
+ 2,
279
+ affine=False)
280
+
281
+ self.conditional = True
282
+
283
+ def forward(self, x, *cond_inputs,
284
+ norm_weights=(None, None), **kwargs):
285
+ r"""Spatially Adaptive Normalization (SPADE) forward.
286
+ Args:
287
+ x (4D tensor) : Input tensor.
288
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
289
+ norm_weights (5D tensor or list of tensors): conv weights or
290
+ [weights, biases].
291
+ Returns:
292
+ output (4D tensor) : Output tensor.
293
+ """
294
+ output = self.norm(x)
295
+ for i in range(len(cond_inputs)):
296
+ if cond_inputs[i] is None:
297
+ continue
298
+ if type(cond_inputs[i]) == list:
299
+ cond_input, mask = cond_inputs[i]
300
+ mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear',
301
+ align_corners=False)
302
+ else:
303
+ cond_input = cond_inputs[i]
304
+ mask = None
305
+ label_map = F.interpolate(cond_input, size=x.size()[2:])
306
+ if norm_weights is None or norm_weights[0] is None or i != 0:
307
+ affine_params = self.mlps[i](label_map)
308
+ else:
309
+ affine_params = self.mlps[i](label_map,
310
+ conv_weights=norm_weights)
311
+ gamma, beta = affine_params.chunk(2, dim=1)
312
+ if mask is not None:
313
+ gamma = gamma * (1 - mask)
314
+ beta = beta * (1 - mask)
315
+ output = output * (1 + gamma) + beta
316
+ return output
317
+
318
+
319
+ class LayerNorm2d(nn.Module):
320
+ r"""Layer Normalization as introduced in
321
+ https://arxiv.org/abs/1607.06450.
322
+ This is the usual way to apply layer normalization in CNNs.
323
+ Note that unlike the pytorch implementation which applies per-element
324
+ scale and bias, here it applies per-channel scale and bias, similar to
325
+ batch/instance normalization.
326
+ Args:
327
+ num_features (int): Number of channels in the input tensor.
328
+ eps (float, optional, default=1e-5): a value added to the
329
+ denominator for numerical stability.
330
+ affine (bool, optional, default=False): If ``True``, performs
331
+ affine transformation after normalization.
332
+ """
333
+
334
+ def __init__(self, num_features, eps=1e-5, affine=True):
335
+ super(LayerNorm2d, self).__init__()
336
+ self.num_features = num_features
337
+ self.affine = affine
338
+ self.eps = eps
339
+
340
+ if self.affine:
341
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
342
+ self.beta = nn.Parameter(torch.zeros(num_features))
343
+
344
+ def forward(self, x):
345
+ r"""
346
+ Args:
347
+ x (tensor): Input tensor.
348
+ """
349
+ shape = [-1] + [1] * (x.dim() - 1)
350
+ if x.size(0) == 1:
351
+ mean = x.view(-1).mean().view(*shape)
352
+ std = x.view(-1).std().view(*shape)
353
+ else:
354
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
355
+ std = x.view(x.size(0), -1).std(1).view(*shape)
356
+
357
+ x = (x - mean) / (std + self.eps)
358
+
359
+ if self.affine:
360
+ shape = [1, -1] + [1] * (x.dim() - 2)
361
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
362
+ return x
363
+
364
+
365
+ def get_activation_norm_layer(num_features, norm_type,
366
+ input_dim, **norm_params):
367
+ r"""Return an activation normalization layer.
368
+ Args:
369
+ num_features (int): Number of feature channels.
370
+ norm_type (str):
371
+ Type of activation normalization.
372
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
373
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
374
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
375
+ input_dim (int): Number of input dimensions.
376
+ norm_params: Arbitrary keyword arguments that will be used to
377
+ initialize the activation normalization.
378
+ """
379
+ input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs
380
+
381
+ if norm_type == 'none' or norm_type == '':
382
+ norm_layer = None
383
+ elif norm_type == 'batch':
384
+ # norm = getattr(nn, 'BatchNorm%dd' % input_dim)
385
+ norm = getattr(sync_batchnorm, 'SynchronizedBatchNorm%dd' % input_dim)
386
+ norm_layer = norm(num_features, **norm_params)
387
+ elif norm_type == 'instance':
388
+ affine = norm_params.pop('affine', True) # Use affine=True by default
389
+ norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
390
+ norm_layer = norm(num_features, affine=affine, **norm_params)
391
+ elif norm_type == 'sync_batch':
392
+ # There is a bug of using amp O1 with synchronize batch norm.
393
+ # The lines below fix it.
394
+ affine = norm_params.pop('affine', True)
395
+ # Always call SyncBN with affine=True
396
+ norm_layer = SyncBatchNorm(num_features, affine=True, **norm_params)
397
+ norm_layer.weight.requires_grad = affine
398
+ norm_layer.bias.requires_grad = affine
399
+ elif norm_type == 'layer':
400
+ norm_layer = nn.LayerNorm(num_features, **norm_params)
401
+ elif norm_type == 'layer_2d':
402
+ norm_layer = LayerNorm2d(num_features, **norm_params)
403
+ elif norm_type == 'group':
404
+ norm_layer = nn.GroupNorm(num_channels=num_features, **norm_params)
405
+ elif norm_type == 'adaptive':
406
+ norm_layer = AdaptiveNorm(num_features, **norm_params)
407
+ elif norm_type == 'spatially_adaptive':
408
+ if input_dim != 2:
409
+ raise ValueError('Spatially adaptive normalization layers '
410
+ 'only supports 2D input')
411
+ norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
412
+ elif norm_type == 'hyper_spatially_adaptive':
413
+ if input_dim != 2:
414
+ raise ValueError('Spatially adaptive normalization layers '
415
+ 'only supports 2D input')
416
+ norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
417
+ else:
418
+ raise ValueError('Activation norm layer %s '
419
+ 'is not recognized' % norm_type)
420
+ return norm_layer
models/layers/conv.py ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from types import SimpleNamespace
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from .misc import ApplyNoise
12
+
13
+
14
+ class _BaseConvBlock(nn.Module):
15
+ r"""An abstract wrapper class that wraps a torch convolution or linear layer
16
+ with normalization and nonlinearity.
17
+ """
18
+
19
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
20
+ padding, dilation, groups, bias, padding_mode,
21
+ weight_norm_type, weight_norm_params,
22
+ activation_norm_type, activation_norm_params,
23
+ nonlinearity, inplace_nonlinearity,
24
+ apply_noise, order, input_dim):
25
+ super().__init__()
26
+ from .nonlinearity import get_nonlinearity_layer
27
+ from .weight_norm import get_weight_norm_layer
28
+ from .activation_norm import get_activation_norm_layer
29
+ self.weight_norm_type = weight_norm_type
30
+
31
+ # Convolutional layer.
32
+ if weight_norm_params is None:
33
+ weight_norm_params = SimpleNamespace()
34
+ weight_norm = get_weight_norm_layer(
35
+ weight_norm_type, **vars(weight_norm_params))
36
+ conv_layer = weight_norm(self._get_conv_layer(
37
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
38
+ groups, bias, padding_mode, input_dim))
39
+
40
+ # Noise injection layer.
41
+ noise_layer = ApplyNoise() if apply_noise else None
42
+
43
+ # Normalization layer.
44
+ conv_before_norm = order.find('C') < order.find('N')
45
+ norm_channels = out_channels if conv_before_norm else in_channels
46
+ if activation_norm_params is None:
47
+ activation_norm_params = SimpleNamespace()
48
+ activation_norm_layer = get_activation_norm_layer(
49
+ norm_channels,
50
+ activation_norm_type,
51
+ input_dim,
52
+ **vars(activation_norm_params))
53
+
54
+ # Nonlinearity layer.
55
+ nonlinearity_layer = get_nonlinearity_layer(
56
+ nonlinearity, inplace=inplace_nonlinearity)
57
+
58
+ # Mapping from operation names to layers.
59
+ mappings = {'C': {'conv': conv_layer},
60
+ 'N': {'norm': activation_norm_layer},
61
+ 'A': {'nonlinearity': nonlinearity_layer}}
62
+
63
+ # All layers in order.
64
+ self.layers = nn.ModuleDict()
65
+ for op in order:
66
+ if list(mappings[op].values())[0] is not None:
67
+ self.layers.update(mappings[op])
68
+ if op == 'C' and noise_layer is not None:
69
+ # Inject noise after convolution.
70
+ self.layers.update({'noise': noise_layer})
71
+
72
+ # Whether this block expects conditional inputs.
73
+ self.conditional = \
74
+ getattr(conv_layer, 'conditional', False) or \
75
+ getattr(activation_norm_layer, 'conditional', False)
76
+
77
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
78
+ r"""
79
+
80
+ Args:
81
+ x (tensor): Input tensor.
82
+ cond_inputs (list of tensors) : Conditional input tensors.
83
+ kw_cond_inputs (dict) : Keyword conditional inputs.
84
+ """
85
+ for layer in self.layers.values():
86
+ if getattr(layer, 'conditional', False):
87
+ # Layers that require conditional inputs.
88
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
89
+ else:
90
+ x = layer(x)
91
+ return x
92
+
93
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
94
+ padding, dilation, groups, bias, padding_mode,
95
+ input_dim):
96
+ # Returns the convolutional layer.
97
+ if input_dim == 0:
98
+ layer = nn.Linear(in_channels, out_channels, bias)
99
+ else:
100
+ layer_type = getattr(nn, 'Conv%dd' % input_dim)
101
+
102
+ layer = layer_type(
103
+ in_channels, out_channels, kernel_size, stride, padding,
104
+ dilation, groups, bias)
105
+ return layer
106
+
107
+ def __repr__(self):
108
+ main_str = self._get_name() + '('
109
+ child_lines = []
110
+ for name, layer in self.layers.items():
111
+ mod_str = repr(layer)
112
+ if name == 'conv' and self.weight_norm_type != 'none' and \
113
+ self.weight_norm_type != '':
114
+ mod_str = mod_str[:-1] + \
115
+ ', weight_norm={}'.format(self.weight_norm_type) + ')'
116
+ mod_str = self._addindent(mod_str, 2)
117
+ child_lines.append(mod_str)
118
+ if len(child_lines) == 1:
119
+ main_str += child_lines[0]
120
+ else:
121
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
122
+
123
+ main_str += ')'
124
+ return main_str
125
+
126
+ @staticmethod
127
+ def _addindent(s_, numSpaces):
128
+ s = s_.split('\n')
129
+ # don't do anything for single-line stuff
130
+ if len(s) == 1:
131
+ return s_
132
+ first = s.pop(0)
133
+ s = [(numSpaces * ' ') + line for line in s]
134
+ s = '\n'.join(s)
135
+ s = first + '\n' + s
136
+ return s
137
+
138
+
139
+ class LinearBlock(_BaseConvBlock):
140
+ r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and
141
+ nonlinearity.
142
+
143
+ Args:
144
+ in_features (int): Number of channels in the input tensor.
145
+ out_features (int): Number of channels in the output tensor.
146
+ bias (bool, optional, default=True):
147
+ If ``True``, adds a learnable bias to the output.
148
+ weight_norm_type (str, optional, default='none'):
149
+ Type of weight normalization.
150
+ ``'none'``, ``'spectral'``, ``'weight'``
151
+ or ``'weight_demod'``.
152
+ weight_norm_params (obj, optional, default=None):
153
+ Parameters of weight normalization.
154
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
155
+ keyword arguments when initializing weight normalization.
156
+ activation_norm_type (str, optional, default='none'):
157
+ Type of activation normalization.
158
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
159
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
160
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
161
+ activation_norm_params (obj, optional, default=None):
162
+ Parameters of activation normalization.
163
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
164
+ keyword arguments when initializing activation normalization.
165
+ nonlinearity (str, optional, default='none'):
166
+ Type of nonlinear activation function.
167
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
168
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
169
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
170
+ set ``inplace=True`` when initializing the nonlinearity layer.
171
+ apply_noise (bool, optional, default=False): If ``True``, add
172
+ Gaussian noise with learnable magnitude after the
173
+ fully-connected layer.
174
+ order (str, optional, default='CNA'): Order of operations.
175
+ ``'C'``: fully-connected,
176
+ ``'N'``: normalization,
177
+ ``'A'``: nonlinear activation.
178
+ For example, a block initialized with ``order='CNA'`` will
179
+ do convolution first, then normalization, then nonlinearity.
180
+ """
181
+
182
+ def __init__(self, in_features, out_features, bias=True,
183
+ weight_norm_type='none', weight_norm_params=None,
184
+ activation_norm_type='none', activation_norm_params=None,
185
+ nonlinearity='none', inplace_nonlinearity=False,
186
+ apply_noise=False, order='CNA'):
187
+ super().__init__(in_features, out_features, None, None,
188
+ None, None, None, bias,
189
+ None, weight_norm_type, weight_norm_params,
190
+ activation_norm_type, activation_norm_params,
191
+ nonlinearity, inplace_nonlinearity, apply_noise,
192
+ order, 0)
193
+
194
+
195
+ class Conv1dBlock(_BaseConvBlock):
196
+ r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and
197
+ nonlinearity.
198
+
199
+ Args:
200
+ in_channels (int): Number of channels in the input tensor.
201
+ out_channels (int): Number of channels in the output tensor.
202
+ kernel_size (int or tuple): Size of the convolving kernel.
203
+ stride (int or tuple, optional, default=1):
204
+ Stride of the convolution.
205
+ padding (int or tuple, optional, default=0):
206
+ Zero-padding added to both sides of the input.
207
+ dilation (int or tuple, optional, default=1):
208
+ Spacing between kernel elements.
209
+ groups (int, optional, default=1): Number of blocked connections
210
+ from input channels to output channels.
211
+ bias (bool, optional, default=True):
212
+ If ``True``, adds a learnable bias to the output.
213
+ padding_mode (string, optional, default='zeros'): Type of padding:
214
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
215
+ weight_norm_type (str, optional, default='none'):
216
+ Type of weight normalization.
217
+ ``'none'``, ``'spectral'``, ``'weight'``
218
+ or ``'weight_demod'``.
219
+ weight_norm_params (obj, optional, default=None):
220
+ Parameters of weight normalization.
221
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
222
+ keyword arguments when initializing weight normalization.
223
+ activation_norm_type (str, optional, default='none'):
224
+ Type of activation normalization.
225
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
226
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
227
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
228
+ activation_norm_params (obj, optional, default=None):
229
+ Parameters of activation normalization.
230
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
231
+ keyword arguments when initializing activation normalization.
232
+ nonlinearity (str, optional, default='none'):
233
+ Type of nonlinear activation function.
234
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
235
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
236
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
237
+ set ``inplace=True`` when initializing the nonlinearity layer.
238
+ apply_noise (bool, optional, default=False): If ``True``, adds
239
+ Gaussian noise with learnable magnitude to the convolution output.
240
+ order (str, optional, default='CNA'): Order of operations.
241
+ ``'C'``: convolution,
242
+ ``'N'``: normalization,
243
+ ``'A'``: nonlinear activation.
244
+ For example, a block initialized with ``order='CNA'`` will
245
+ do convolution first, then normalization, then nonlinearity.
246
+ """
247
+
248
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
249
+ padding=0, dilation=1, groups=1, bias=True,
250
+ padding_mode='zeros',
251
+ weight_norm_type='none', weight_norm_params=None,
252
+ activation_norm_type='none', activation_norm_params=None,
253
+ nonlinearity='none', inplace_nonlinearity=False,
254
+ apply_noise=False, order='CNA'):
255
+ super().__init__(in_channels, out_channels, kernel_size, stride,
256
+ padding, dilation, groups, bias, padding_mode,
257
+ weight_norm_type, weight_norm_params,
258
+ activation_norm_type, activation_norm_params,
259
+ nonlinearity, inplace_nonlinearity, apply_noise,
260
+ order, 1)
261
+
262
+
263
+ class Conv2dBlock(_BaseConvBlock):
264
+ r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
265
+ nonlinearity.
266
+
267
+ Args:
268
+ in_channels (int): Number of channels in the input tensor.
269
+ out_channels (int): Number of channels in the output tensor.
270
+ kernel_size (int or tuple): Size of the convolving kernel.
271
+ stride (int or tuple, optional, default=1):
272
+ Stride of the convolution.
273
+ padding (int or tuple, optional, default=0):
274
+ Zero-padding added to both sides of the input.
275
+ dilation (int or tuple, optional, default=1):
276
+ Spacing between kernel elements.
277
+ groups (int, optional, default=1): Number of blocked connections
278
+ from input channels to output channels.
279
+ bias (bool, optional, default=True):
280
+ If ``True``, adds a learnable bias to the output.
281
+ padding_mode (string, optional, default='zeros'): Type of padding:
282
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
283
+ weight_norm_type (str, optional, default='none'):
284
+ Type of weight normalization.
285
+ ``'none'``, ``'spectral'``, ``'weight'``
286
+ or ``'weight_demod'``.
287
+ weight_norm_params (obj, optional, default=None):
288
+ Parameters of weight normalization.
289
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
290
+ keyword arguments when initializing weight normalization.
291
+ activation_norm_type (str, optional, default='none'):
292
+ Type of activation normalization.
293
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
294
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
295
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
296
+ activation_norm_params (obj, optional, default=None):
297
+ Parameters of activation normalization.
298
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
299
+ keyword arguments when initializing activation normalization.
300
+ nonlinearity (str, optional, default='none'):
301
+ Type of nonlinear activation function.
302
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
303
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
304
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
305
+ set ``inplace=True`` when initializing the nonlinearity layer.
306
+ apply_noise (bool, optional, default=False): If ``True``, adds
307
+ Gaussian noise with learnable magnitude to the convolution output.
308
+ order (str, optional, default='CNA'): Order of operations.
309
+ ``'C'``: convolution,
310
+ ``'N'``: normalization,
311
+ ``'A'``: nonlinear activation.
312
+ For example, a block initialized with ``order='CNA'`` will
313
+ do convolution first, then normalization, then nonlinearity.
314
+ """
315
+
316
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
317
+ padding=0, dilation=1, groups=1, bias=True,
318
+ padding_mode='zeros',
319
+ weight_norm_type='none', weight_norm_params=None,
320
+ activation_norm_type='none', activation_norm_params=None,
321
+ nonlinearity='none', inplace_nonlinearity=False,
322
+ apply_noise=False, order='CNA'):
323
+ super().__init__(in_channels, out_channels, kernel_size, stride,
324
+ padding, dilation, groups, bias, padding_mode,
325
+ weight_norm_type, weight_norm_params,
326
+ activation_norm_type, activation_norm_params,
327
+ nonlinearity, inplace_nonlinearity,
328
+ apply_noise, order, 2)
329
+
330
+
331
+ class Conv3dBlock(_BaseConvBlock):
332
+ r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and
333
+ nonlinearity.
334
+
335
+ Args:
336
+ in_channels (int): Number of channels in the input tensor.
337
+ out_channels (int): Number of channels in the output tensor.
338
+ kernel_size (int or tuple): Size of the convolving kernel.
339
+ stride (int or tuple, optional, default=1):
340
+ Stride of the convolution.
341
+ padding (int or tuple, optional, default=0):
342
+ Zero-padding added to both sides of the input.
343
+ dilation (int or tuple, optional, default=1):
344
+ Spacing between kernel elements.
345
+ groups (int, optional, default=1): Number of blocked connections
346
+ from input channels to output channels.
347
+ bias (bool, optional, default=True):
348
+ If ``True``, adds a learnable bias to the output.
349
+ padding_mode (string, optional, default='zeros'): Type of padding:
350
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
351
+ weight_norm_type (str, optional, default='none'):
352
+ Type of weight normalization.
353
+ ``'none'``, ``'spectral'``, ``'weight'``
354
+ or ``'weight_demod'``.
355
+ weight_norm_params (obj, optional, default=None):
356
+ Parameters of weight normalization.
357
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
358
+ keyword arguments when initializing weight normalization.
359
+ activation_norm_type (str, optional, default='none'):
360
+ Type of activation normalization.
361
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
362
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
363
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
364
+ activation_norm_params (obj, optional, default=None):
365
+ Parameters of activation normalization.
366
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
367
+ keyword arguments when initializing activation normalization.
368
+ nonlinearity (str, optional, default='none'):
369
+ Type of nonlinear activation function.
370
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
371
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
372
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
373
+ set ``inplace=True`` when initializing the nonlinearity layer.
374
+ apply_noise (bool, optional, default=False): If ``True``, adds
375
+ Gaussian noise with learnable magnitude to the convolution output.
376
+ order (str, optional, default='CNA'): Order of operations.
377
+ ``'C'``: convolution,
378
+ ``'N'``: normalization,
379
+ ``'A'``: nonlinear activation.
380
+ For example, a block initialized with ``order='CNA'`` will
381
+ do convolution first, then normalization, then nonlinearity.
382
+ """
383
+
384
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
385
+ padding=0, dilation=1, groups=1, bias=True,
386
+ padding_mode='zeros',
387
+ weight_norm_type='none', weight_norm_params=None,
388
+ activation_norm_type='none', activation_norm_params=None,
389
+ nonlinearity='none', inplace_nonlinearity=False,
390
+ apply_noise=False,
391
+ order='CNA'):
392
+ super().__init__(in_channels, out_channels, kernel_size, stride,
393
+ padding, dilation, groups, bias, padding_mode,
394
+ weight_norm_type, weight_norm_params,
395
+ activation_norm_type, activation_norm_params,
396
+ nonlinearity, inplace_nonlinearity,
397
+ apply_noise, order, 3)
398
+
399
+
400
+ class _BaseHyperConvBlock(_BaseConvBlock):
401
+ r"""An abstract wrapper class that wraps a hyper convolutional layer
402
+ with normalization and nonlinearity.
403
+ """
404
+
405
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
406
+ padding, dilation, groups, bias,
407
+ padding_mode,
408
+ weight_norm_type, weight_norm_params,
409
+ activation_norm_type, activation_norm_params,
410
+ nonlinearity, inplace_nonlinearity, apply_noise,
411
+ is_hyper_conv, is_hyper_norm,
412
+ order, input_dim):
413
+ self.is_hyper_conv = is_hyper_conv
414
+ if is_hyper_conv:
415
+ weight_norm_type = 'none'
416
+ if is_hyper_norm:
417
+ activation_norm_type = 'hyper_' + activation_norm_type
418
+ super().__init__(in_channels, out_channels, kernel_size, stride,
419
+ padding, dilation, groups, bias, padding_mode,
420
+ weight_norm_type, weight_norm_params,
421
+ activation_norm_type, activation_norm_params,
422
+ nonlinearity, inplace_nonlinearity, apply_noise,
423
+ order, input_dim)
424
+
425
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
426
+ padding, dilation, groups, bias, padding_mode,
427
+ input_dim):
428
+ if input_dim == 0:
429
+ raise ValueError('HyperLinearBlock is not supported.')
430
+ else:
431
+ name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv'
432
+ layer_type = eval(name + '%dd' % input_dim)
433
+ layer = layer_type(
434
+ in_channels, out_channels, kernel_size, stride, padding,
435
+ dilation, groups, bias, padding_mode)
436
+ return layer
437
+
438
+
439
+ class HyperConv2dBlock(_BaseHyperConvBlock):
440
+ r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and
441
+ nonlinearity.
442
+
443
+ Args:
444
+ in_channels (int): Number of channels in the input tensor.
445
+ out_channels (int): Number of channels in the output tensor.
446
+ kernel_size (int or tuple): Size of the convolving kernel.
447
+ stride (int or tuple, optional, default=1):
448
+ Stride of the convolution.
449
+ padding (int or tuple, optional, default=0):
450
+ Zero-padding added to both sides of the input.
451
+ dilation (int or tuple, optional, default=1):
452
+ Spacing between kernel elements.
453
+ groups (int, optional, default=1): Number of blocked connections
454
+ from input channels to output channels.
455
+ bias (bool, optional, default=True):
456
+ If ``True``, adds a learnable bias to the output.
457
+ padding_mode (string, optional, default='zeros'): Type of padding:
458
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
459
+ weight_norm_type (str, optional, default='none'):
460
+ Type of weight normalization.
461
+ ``'none'``, ``'spectral'``, ``'weight'``
462
+ or ``'weight_demod'``.
463
+ weight_norm_params (obj, optional, default=None):
464
+ Parameters of weight normalization.
465
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
466
+ keyword arguments when initializing weight normalization.
467
+ activation_norm_type (str, optional, default='none'):
468
+ Type of activation normalization.
469
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
470
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
471
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
472
+ activation_norm_params (obj, optional, default=None):
473
+ Parameters of activation normalization.
474
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
475
+ keyword arguments when initializing activation normalization.
476
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
477
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
478
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
479
+ hyper normalizations.
480
+ nonlinearity (str, optional, default='none'):
481
+ Type of nonlinear activation function.
482
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
483
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
484
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
485
+ set ``inplace=True`` when initializing the nonlinearity layer.
486
+ apply_noise (bool, optional, default=False): If ``True``, adds
487
+ Gaussian noise with learnable magnitude to the convolution output.
488
+ order (str, optional, default='CNA'): Order of operations.
489
+ ``'C'``: convolution,
490
+ ``'N'``: normalization,
491
+ ``'A'``: nonlinear activation.
492
+ For example, a block initialized with ``order='CNA'`` will
493
+ do convolution first, then normalization, then nonlinearity.
494
+ """
495
+
496
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
497
+ padding=0, dilation=1, groups=1, bias=True,
498
+ padding_mode='zeros',
499
+ weight_norm_type='none', weight_norm_params=None,
500
+ activation_norm_type='none', activation_norm_params=None,
501
+ is_hyper_conv=False, is_hyper_norm=False,
502
+ nonlinearity='none', inplace_nonlinearity=False,
503
+ apply_noise=False, order='CNA'):
504
+ super().__init__(in_channels, out_channels, kernel_size, stride,
505
+ padding, dilation, groups, bias, padding_mode,
506
+ weight_norm_type, weight_norm_params,
507
+ activation_norm_type, activation_norm_params,
508
+ nonlinearity, inplace_nonlinearity, apply_noise,
509
+ is_hyper_conv, is_hyper_norm, order, 2)
510
+
511
+
512
+ class HyperConv2d(nn.Module):
513
+ r"""Hyper Conv2d initialization.
514
+
515
+ Args:
516
+ in_channels (int): Dummy parameter.
517
+ out_channels (int): Dummy parameter.
518
+ kernel_size (int or tuple): Dummy parameter.
519
+ stride (int or tuple, optional, default=1):
520
+ Stride of the convolution. Default: 1
521
+ padding (int or tuple, optional, default=0):
522
+ Zero-padding added to both sides of the input.
523
+ padding_mode (string, optional, default='zeros'):
524
+ ``'zeros'``, ``'reflect'``, ``'replicate'``
525
+ or ``'circular'``.
526
+ dilation (int or tuple, optional, default=1):
527
+ Spacing between kernel elements.
528
+ groups (int, optional, default=1): Number of blocked connections
529
+ from input channels to output channels.
530
+ bias (bool, optional, default=True): If ``True``,
531
+ adds a learnable bias to the output.
532
+ """
533
+
534
+ def __init__(self, in_channels=0, out_channels=0, kernel_size=3,
535
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
536
+ padding_mode='zeros'):
537
+ super().__init__()
538
+ self.stride = stride
539
+ self.padding = padding
540
+ self.dilation = dilation
541
+ self.groups = groups
542
+ self.use_bias = bias
543
+ self.padding_mode = padding_mode
544
+ self.conditional = True
545
+
546
+ def forward(self, x, *args, conv_weights=(None, None), **kwargs):
547
+ r"""Hyper Conv2d forward. Convolve x using the provided weight and bias.
548
+
549
+ Args:
550
+ x (N x C x H x W tensor): Input tensor.
551
+ conv_weights (N x C2 x C1 x k x k tensor or list of tensors):
552
+ Convolution weights or [weight, bias].
553
+ Returns:
554
+ y (N x C2 x H x W tensor): Output tensor.
555
+ """
556
+ if conv_weights is None:
557
+ conv_weight, conv_bias = None, None
558
+ elif isinstance(conv_weights, torch.Tensor):
559
+ conv_weight, conv_bias = conv_weights, None
560
+ else:
561
+ conv_weight, conv_bias = conv_weights
562
+
563
+ if conv_weight is None:
564
+ return x
565
+ if conv_bias is None:
566
+ if self.use_bias:
567
+ raise ValueError('bias not provided but set to true during '
568
+ 'initialization')
569
+ conv_bias = [None] * x.size(0)
570
+ if self.padding_mode != 'zeros':
571
+ x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
572
+ padding = 0
573
+ else:
574
+ padding = self.padding
575
+
576
+ y = None
577
+ for i in range(x.size(0)):
578
+ if self.stride >= 1:
579
+ yi = F.conv2d(x[i: i + 1],
580
+ weight=conv_weight[i], bias=conv_bias[i],
581
+ stride=self.stride, padding=padding,
582
+ dilation=self.dilation, groups=self.groups)
583
+ else:
584
+ yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i],
585
+ bias=conv_bias[i], padding=self.padding,
586
+ stride=int(1 / self.stride),
587
+ dilation=self.dilation,
588
+ output_padding=self.padding,
589
+ groups=self.groups)
590
+ y = torch.cat([y, yi]) if y is not None else yi
591
+ return y
592
+
593
+
594
+ class _BasePartialConvBlock(_BaseConvBlock):
595
+ r"""An abstract wrapper class that wraps a partial convolutional layer
596
+ with normalization and nonlinearity.
597
+ """
598
+
599
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
600
+ padding, dilation, groups, bias, padding_mode,
601
+ weight_norm_type, weight_norm_params,
602
+ activation_norm_type, activation_norm_params,
603
+ nonlinearity, inplace_nonlinearity,
604
+ multi_channel, return_mask,
605
+ apply_noise, order, input_dim):
606
+ self.multi_channel = multi_channel
607
+ self.return_mask = return_mask
608
+ self.partial_conv = True
609
+ super().__init__(in_channels, out_channels, kernel_size, stride,
610
+ padding, dilation, groups, bias, padding_mode,
611
+ weight_norm_type, weight_norm_params,
612
+ activation_norm_type, activation_norm_params,
613
+ nonlinearity, inplace_nonlinearity, apply_noise,
614
+ order, input_dim)
615
+
616
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
617
+ padding, dilation, groups, bias, padding_mode,
618
+ input_dim):
619
+ if input_dim == 2:
620
+ layer_type = PartialConv2d
621
+ elif input_dim == 3:
622
+ layer_type = PartialConv3d
623
+ else:
624
+ raise ValueError('Partial conv only supports 2D and 3D conv now.')
625
+ layer = layer_type(
626
+ in_channels, out_channels, kernel_size, stride, padding,
627
+ dilation, groups, bias, padding_mode,
628
+ multi_channel=self.multi_channel, return_mask=self.return_mask)
629
+ return layer
630
+
631
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
632
+ r"""
633
+
634
+ Args:
635
+ x (tensor): Input tensor.
636
+ cond_inputs (list of tensors) : Conditional input tensors.
637
+ mask_in (tensor, optional, default=``None``) If not ``None``,
638
+ it masks the valid input region.
639
+ kw_cond_inputs (dict) : Keyword conditional inputs.
640
+ Returns:
641
+ (tuple):
642
+ - x (tensor): Output tensor.
643
+ - mask_out (tensor, optional): Masks the valid output region.
644
+ """
645
+ mask_out = None
646
+ for layer in self.layers.values():
647
+ if getattr(layer, 'conditional', False):
648
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
649
+ elif getattr(layer, 'partial_conv', False):
650
+ x = layer(x, mask_in=mask_in, **kw_cond_inputs)
651
+ if type(x) == tuple:
652
+ x, mask_out = x
653
+ else:
654
+ x = layer(x)
655
+
656
+ if mask_out is not None:
657
+ return x, mask_out
658
+ return x
659
+
660
+
661
+ class PartialConv2dBlock(_BasePartialConvBlock):
662
+ r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and
663
+ nonlinearity.
664
+
665
+ Args:
666
+ in_channels (int): Number of channels in the input tensor.
667
+ out_channels (int): Number of channels in the output tensor.
668
+ kernel_size (int or tuple): Size of the convolving kernel.
669
+ stride (int or tuple, optional, default=1):
670
+ Stride of the convolution.
671
+ padding (int or tuple, optional, default=0):
672
+ Zero-padding added to both sides of the input.
673
+ dilation (int or tuple, optional, default=1):
674
+ Spacing between kernel elements.
675
+ groups (int, optional, default=1): Number of blocked connections
676
+ from input channels to output channels.
677
+ bias (bool, optional, default=True):
678
+ If ``True``, adds a learnable bias to the output.
679
+ padding_mode (string, optional, default='zeros'): Type of padding:
680
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
681
+ weight_norm_type (str, optional, default='none'):
682
+ Type of weight normalization.
683
+ ``'none'``, ``'spectral'``, ``'weight'``
684
+ or ``'weight_demod'``.
685
+ weight_norm_params (obj, optional, default=None):
686
+ Parameters of weight normalization.
687
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
688
+ keyword arguments when initializing weight normalization.
689
+ activation_norm_type (str, optional, default='none'):
690
+ Type of activation normalization.
691
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
692
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
693
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
694
+ activation_norm_params (obj, optional, default=None):
695
+ Parameters of activation normalization.
696
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
697
+ keyword arguments when initializing activation normalization.
698
+ nonlinearity (str, optional, default='none'):
699
+ Type of nonlinear activation function.
700
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
701
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
702
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
703
+ set ``inplace=True`` when initializing the nonlinearity layer.
704
+ apply_noise (bool, optional, default=False): If ``True``, adds
705
+ Gaussian noise with learnable magnitude to the convolution output.
706
+ order (str, optional, default='CNA'): Order of operations.
707
+ ``'C'``: convolution,
708
+ ``'N'``: normalization,
709
+ ``'A'``: nonlinear activation.
710
+ For example, a block initialized with ``order='CNA'`` will
711
+ do convolution first, then normalization, then nonlinearity.
712
+ multi_channel (bool, optional, default=False): If ``True``, use
713
+ different masks for different channels.
714
+ return_mask (bool, optional, default=True): If ``True``, the
715
+ forward call also returns a new mask.
716
+ """
717
+
718
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
719
+ padding=0, dilation=1, groups=1, bias=True,
720
+ padding_mode='zeros',
721
+ weight_norm_type='none', weight_norm_params=None,
722
+ activation_norm_type='none', activation_norm_params=None,
723
+ nonlinearity='none', inplace_nonlinearity=False,
724
+ multi_channel=False, return_mask=True,
725
+ apply_noise=False, order='CNA'):
726
+ super().__init__(in_channels, out_channels, kernel_size, stride,
727
+ padding, dilation, groups, bias, padding_mode,
728
+ weight_norm_type, weight_norm_params,
729
+ activation_norm_type, activation_norm_params,
730
+ nonlinearity, inplace_nonlinearity,
731
+ multi_channel, return_mask, apply_noise, order, 2)
732
+
733
+
734
+ class PartialConv3dBlock(_BasePartialConvBlock):
735
+ r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and
736
+ nonlinearity.
737
+
738
+ Args:
739
+ in_channels (int): Number of channels in the input tensor.
740
+ out_channels (int): Number of channels in the output tensor.
741
+ kernel_size (int or tuple): Size of the convolving kernel.
742
+ stride (int or tuple, optional, default=1):
743
+ Stride of the convolution.
744
+ padding (int or tuple, optional, default=0):
745
+ Zero-padding added to both sides of the input.
746
+ dilation (int or tuple, optional, default=1):
747
+ Spacing between kernel elements.
748
+ groups (int, optional, default=1): Number of blocked connections
749
+ from input channels to output channels.
750
+ bias (bool, optional, default=True):
751
+ If ``True``, adds a learnable bias to the output.
752
+ padding_mode (string, optional, default='zeros'): Type of padding:
753
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
754
+ weight_norm_type (str, optional, default='none'):
755
+ Type of weight normalization.
756
+ ``'none'``, ``'spectral'``, ``'weight'``
757
+ or ``'weight_demod'``.
758
+ weight_norm_params (obj, optional, default=None):
759
+ Parameters of weight normalization.
760
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
761
+ keyword arguments when initializing weight normalization.
762
+ activation_norm_type (str, optional, default='none'):
763
+ Type of activation normalization.
764
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
765
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
766
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
767
+ activation_norm_params (obj, optional, default=None):
768
+ Parameters of activation normalization.
769
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
770
+ keyword arguments when initializing activation normalization.
771
+ nonlinearity (str, optional, default='none'):
772
+ Type of nonlinear activation function.
773
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
774
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
775
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
776
+ set ``inplace=True`` when initializing the nonlinearity layer.
777
+ apply_noise (bool, optional, default=False): If ``True``, adds
778
+ Gaussian noise with learnable magnitude to the convolution output.
779
+ order (str, optional, default='CNA'): Order of operations.
780
+ ``'C'``: convolution,
781
+ ``'N'``: normalization,
782
+ ``'A'``: nonlinear activation.
783
+ For example, a block initialized with ``order='CNA'`` will
784
+ do convolution first, then normalization, then nonlinearity.
785
+ multi_channel (bool, optional, default=False): If ``True``, use
786
+ different masks for different channels.
787
+ return_mask (bool, optional, default=True): If ``True``, the
788
+ forward call also returns a new mask.
789
+ """
790
+
791
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
792
+ padding=0, dilation=1, groups=1, bias=True,
793
+ padding_mode='zeros',
794
+ weight_norm_type='none', weight_norm_params=None,
795
+ activation_norm_type='none', activation_norm_params=None,
796
+ nonlinearity='none', inplace_nonlinearity=False,
797
+ multi_channel=False, return_mask=True,
798
+ apply_noise=False, order='CNA'):
799
+ super().__init__(in_channels, out_channels, kernel_size, stride,
800
+ padding, dilation, groups, bias, padding_mode,
801
+ weight_norm_type, weight_norm_params,
802
+ activation_norm_type, activation_norm_params,
803
+ nonlinearity, inplace_nonlinearity,
804
+ multi_channel, return_mask, apply_noise, order, 3)
805
+
806
+
807
+ class _MultiOutBaseConvBlock(_BaseConvBlock):
808
+ r"""An abstract wrapper class that wraps a hyper convolutional layer with
809
+ normalization and nonlinearity. It can return multiple outputs, if some
810
+ layers in the block return more than one output.
811
+ """
812
+
813
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
814
+ padding, dilation, groups, bias,
815
+ padding_mode,
816
+ weight_norm_type, weight_norm_params,
817
+ activation_norm_type, activation_norm_params,
818
+ nonlinearity, inplace_nonlinearity,
819
+ apply_noise, order, input_dim):
820
+ super().__init__(in_channels, out_channels, kernel_size, stride,
821
+ padding, dilation, groups, bias, padding_mode,
822
+ weight_norm_type, weight_norm_params,
823
+ activation_norm_type, activation_norm_params,
824
+ nonlinearity, inplace_nonlinearity,
825
+ apply_noise, order, input_dim)
826
+ self.multiple_outputs = True
827
+
828
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
829
+ r"""
830
+
831
+ Args:
832
+ x (tensor): Input tensor.
833
+ cond_inputs (list of tensors) : Conditional input tensors.
834
+ kw_cond_inputs (dict) : Keyword conditional inputs.
835
+ Returns:
836
+ (tuple):
837
+ - x (tensor): Main output tensor.
838
+ - other_outputs (list of tensors): Other output tensors.
839
+ """
840
+ other_outputs = []
841
+ for layer in self.layers.values():
842
+ if getattr(layer, 'conditional', False):
843
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
844
+ elif getattr(layer, 'multiple_outputs', False):
845
+ x, other_output = layer(x)
846
+ other_outputs.append(other_output)
847
+ else:
848
+ x = layer(x)
849
+ return (x, *other_outputs)
850
+
851
+
852
+ class MultiOutConv2dBlock(_MultiOutBaseConvBlock):
853
+ r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
854
+ nonlinearity. It can return multiple outputs, if some layers in the block
855
+ return more than one output.
856
+
857
+ Args:
858
+ in_channels (int): Number of channels in the input tensor.
859
+ out_channels (int): Number of channels in the output tensor.
860
+ kernel_size (int or tuple): Size of the convolving kernel.
861
+ stride (int or tuple, optional, default=1):
862
+ Stride of the convolution.
863
+ padding (int or tuple, optional, default=0):
864
+ Zero-padding added to both sides of the input.
865
+ dilation (int or tuple, optional, default=1):
866
+ Spacing between kernel elements.
867
+ groups (int, optional, default=1): Number of blocked connections
868
+ from input channels to output channels.
869
+ bias (bool, optional, default=True):
870
+ If ``True``, adds a learnable bias to the output.
871
+ padding_mode (string, optional, default='zeros'): Type of padding:
872
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
873
+ weight_norm_type (str, optional, default='none'):
874
+ Type of weight normalization.
875
+ ``'none'``, ``'spectral'``, ``'weight'``
876
+ or ``'weight_demod'``.
877
+ weight_norm_params (obj, optional, default=None):
878
+ Parameters of weight normalization.
879
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
880
+ keyword arguments when initializing weight normalization.
881
+ activation_norm_type (str, optional, default='none'):
882
+ Type of activation normalization.
883
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
884
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
885
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
886
+ activation_norm_params (obj, optional, default=None):
887
+ Parameters of activation normalization.
888
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
889
+ keyword arguments when initializing activation normalization.
890
+ nonlinearity (str, optional, default='none'):
891
+ Type of nonlinear activation function.
892
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
893
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
894
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
895
+ set ``inplace=True`` when initializing the nonlinearity layer.
896
+ apply_noise (bool, optional, default=False): If ``True``, adds
897
+ Gaussian noise with learnable magnitude to the convolution output.
898
+ order (str, optional, default='CNA'): Order of operations.
899
+ ``'C'``: convolution,
900
+ ``'N'``: normalization,
901
+ ``'A'``: nonlinear activation.
902
+ For example, a block initialized with ``order='CNA'`` will
903
+ do convolution first, then normalization, then nonlinearity.
904
+ """
905
+
906
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
907
+ padding=0, dilation=1, groups=1, bias=True,
908
+ padding_mode='zeros',
909
+ weight_norm_type='none', weight_norm_params=None,
910
+ activation_norm_type='none', activation_norm_params=None,
911
+ nonlinearity='none', inplace_nonlinearity=False,
912
+ apply_noise=False, order='CNA'):
913
+ super().__init__(in_channels, out_channels, kernel_size, stride,
914
+ padding, dilation, groups, bias, padding_mode,
915
+ weight_norm_type, weight_norm_params,
916
+ activation_norm_type, activation_norm_params,
917
+ nonlinearity, inplace_nonlinearity,
918
+ apply_noise, order, 2)
919
+
920
+
921
+ ###############################################################################
922
+ # BSD 3-Clause License
923
+ #
924
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
925
+ #
926
+ # Author & Contact: Guilin Liu ([email protected])
927
+ ###############################################################################
928
+ class PartialConv2d(nn.Conv2d):
929
+ r"""Partial 2D convolution in
930
+ "Image inpainting for irregular holes using partial convolutions."
931
+ Liu et al., ECCV 2018
932
+ """
933
+
934
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
935
+ # whether the mask is multi-channel or not
936
+ self.multi_channel = multi_channel
937
+ self.return_mask = return_mask
938
+ super(PartialConv2d, self).__init__(*args, **kwargs)
939
+
940
+ if self.multi_channel:
941
+ self.weight_maskUpdater = torch.ones(self.out_channels,
942
+ self.in_channels,
943
+ self.kernel_size[0],
944
+ self.kernel_size[1])
945
+ else:
946
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
947
+ self.kernel_size[1])
948
+
949
+ shape = self.weight_maskUpdater.shape
950
+ self.slide_winsize = shape[1] * shape[2] * shape[3]
951
+
952
+ self.last_size = (None, None, None, None)
953
+ self.update_mask = None
954
+ self.mask_ratio = None
955
+ self.partial_conv = True
956
+
957
+ def forward(self, x, mask_in=None):
958
+ r"""
959
+
960
+ Args:
961
+ x (tensor): Input tensor.
962
+ mask_in (tensor, optional, default=``None``) If not ``None``,
963
+ it masks the valid input region.
964
+ """
965
+ assert len(x.shape) == 4
966
+ if mask_in is not None or self.last_size != tuple(x.shape):
967
+ self.last_size = tuple(x.shape)
968
+
969
+ with torch.no_grad():
970
+ if self.weight_maskUpdater.type() != x.type():
971
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
972
+
973
+ if mask_in is None:
974
+ # If mask is not provided, create a mask.
975
+ if self.multi_channel:
976
+ mask = torch.ones(x.data.shape[0],
977
+ x.data.shape[1],
978
+ x.data.shape[2],
979
+ x.data.shape[3]).to(x)
980
+ else:
981
+ mask = torch.ones(1, 1, x.data.shape[2],
982
+ x.data.shape[3]).to(x)
983
+ else:
984
+ mask = mask_in
985
+
986
+ self.update_mask = F.conv2d(mask, self.weight_maskUpdater,
987
+ bias=None, stride=self.stride,
988
+ padding=self.padding,
989
+ dilation=self.dilation, groups=1)
990
+
991
+ # For mixed precision training, eps from 1e-8 to 1e-6.
992
+ eps = 1e-6
993
+ self.mask_ratio = self.slide_winsize / (self.update_mask + eps)
994
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
995
+ self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
996
+
997
+ raw_out = super(PartialConv2d, self).forward(
998
+ torch.mul(x, mask) if mask_in is not None else x)
999
+
1000
+ if self.bias is not None:
1001
+ bias_view = self.bias.view(1, self.out_channels, 1, 1)
1002
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
1003
+ output = torch.mul(output, self.update_mask)
1004
+ else:
1005
+ output = torch.mul(raw_out, self.mask_ratio)
1006
+
1007
+ if self.return_mask:
1008
+ return output, self.update_mask
1009
+ else:
1010
+ return output
1011
+
1012
+
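# Editor's usage sketch, not part of the uploaded file: a minimal example of
# driving PartialConv2d above with an explicit hole mask. The helper name and
# the shapes are arbitrary; remove it if you copy this module verbatim.
def _example_partial_conv2d():
    """Run PartialConv2d on a random image with a square hole in the mask."""
    conv = PartialConv2d(3, 16, kernel_size=3, padding=1,
                         multi_channel=False, return_mask=True)
    x = torch.randn(1, 3, 64, 64)
    mask = torch.ones(1, 1, 64, 64)
    mask[:, :, 16:32, 16:32] = 0  # simulate a missing region
    out, new_mask = conv(x, mask_in=mask)
    # out: (1, 16, 64, 64); new_mask: (1, 1, 64, 64), with the invalid region
    # shrunk as valid context is propagated by the kernel.
    return out, new_mask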
1013
+ class PartialConv3d(nn.Conv3d):
1014
+ r"""Partial 3D convolution in
1015
+ "Image inpainting for irregular holes using partial convolutions."
1016
+ Liu et al., ECCV 2018
1017
+ """
1018
+
1019
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
1020
+ # whether the mask is multi-channel or not
1021
+ self.multi_channel = multi_channel
1022
+ self.return_mask = return_mask
1023
+ super(PartialConv3d, self).__init__(*args, **kwargs)
1024
+
1025
+ if self.multi_channel:
1026
+ self.weight_maskUpdater = \
1027
+ torch.ones(self.out_channels, self.in_channels,
1028
+ self.kernel_size[0], self.kernel_size[1],
1029
+ self.kernel_size[2])
1030
+ else:
1031
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
1032
+ self.kernel_size[1],
1033
+ self.kernel_size[2])
1034
+ self.weight_maskUpdater = self.weight_maskUpdater.to('cuda')
1035
+
1036
+ shape = self.weight_maskUpdater.shape
1037
+ self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]
1038
+ self.partial_conv = True
1039
+
1040
+ def forward(self, x, mask_in=None):
1041
+ r"""
1042
+
1043
+ Args:
1044
+ x (tensor): Input tensor.
1045
+ mask_in (tensor): Masks the valid input region. Unlike
1046
+ ``PartialConv2d``, this implementation requires ``mask_in``.
1047
+ """
1048
+ assert len(x.shape) == 5
1049
+
1050
+ with torch.no_grad():
1051
+ mask = mask_in
1052
+ update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None,
1053
+ stride=self.stride, padding=self.padding,
1054
+ dilation=self.dilation, groups=1)
1055
+
1056
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
1057
+ update_mask = torch.clamp(update_mask, 0, 1)
1058
+ mask_ratio = torch.mul(mask_ratio, update_mask)
1059
+
1060
+ raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in))
1061
+
1062
+ if self.bias is not None:
1063
+ bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
1064
+ output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
1065
+ if mask_in is not None:
1066
+ output = torch.mul(output, update_mask)
1067
+ else:
1068
+ output = torch.mul(raw_out, mask_ratio)
1069
+
1070
+ if self.return_mask:
1071
+ return output, update_mask
1072
+ else:
1073
+ return output
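# Editor's usage sketch, not part of the uploaded file. Note two differences
# from PartialConv2d above: this PartialConv3d uses mask_in unconditionally
# inside forward (so a mask must always be supplied), and its mask-update
# weights are moved to CUDA in __init__, so the example assumes a GPU.
def _example_partial_conv3d():
    """Run PartialConv3d on a random clip with an all-ones validity mask."""
    conv = PartialConv3d(3, 8, kernel_size=3, padding=1,
                         multi_channel=False, return_mask=True).to('cuda')
    x = torch.randn(1, 3, 4, 32, 32, device='cuda')
    mask = torch.ones(1, 1, 4, 32, 32, device='cuda')
    out, new_mask = conv(x, mask_in=mask)
    # out: (1, 8, 4, 32, 32); new_mask: (1, 1, 4, 32, 32).
    return out, new_mask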
models/layers/misc.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ from torch import nn
7
+
8
+
9
+ class ApplyNoise(nn.Module):
10
+ r"""Add Gaussian noise to the input tensor."""
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+ # scale of the noise
15
+ self.weight = nn.Parameter(torch.zeros(1))
16
+
17
+ def forward(self, x, noise=None):
18
+ r"""
19
+
20
+ Args:
21
+ x (tensor): Input tensor.
22
+ noise (tensor, optional, default=``None``) : Noise tensor to be
23
+ added to the input.
24
+ """
25
+ if noise is None:
26
+ sz = x.size()
27
+ noise = x.new_empty(sz[0], 1, *sz[2:]).normal_()
28
+
29
+ return x + self.weight * noise
30
+
31
+
32
+ class PartialSequential(nn.Sequential):
33
+ r"""Sequential block for partial convolutions."""
34
+ def __init__(self, *modules):
35
+ super(PartialSequential, self).__init__(*modules)
36
+
37
+ def forward(self, x):
38
+ r"""
39
+
40
+ Args:
41
+ x (tensor): Input tensor.
42
+ """
43
+ act = x[:, :-1]
44
+ mask = x[:, -1].unsqueeze(1)
45
+ for module in self:
46
+ act, mask = module(act, mask_in=mask)
47
+ return act
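# Editor's usage sketch, not part of the uploaded file. PartialSequential
# expects the validity mask to be concatenated to the input as its last
# channel; every chained module must accept (x, mask_in=...) and return
# (output, mask), as PartialConv2d from .conv does. The import is done lazily
# inside the function to avoid a circular import with .conv.
def _example_partial_sequential():
    """Chain two partial convolutions, passing the mask as the last channel."""
    from .conv import PartialConv2d
    net = PartialSequential(
        PartialConv2d(3, 8, kernel_size=3, padding=1, return_mask=True),
        PartialConv2d(8, 8, kernel_size=3, padding=1, return_mask=True))
    x = torch.randn(1, 3, 32, 32)
    mask = torch.ones(1, 1, 32, 32)
    out = net(torch.cat([x, mask], dim=1))  # (1, 8, 32, 32)
    return out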
models/layers/non_local.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from imaginaire.layers import Conv2dBlock
11
+
12
+
13
+ class NonLocal2dBlock(nn.Module):
14
+ r"""Self attention Layer
15
+
16
+ Args:
17
+ in_channels (int): Number of channels in the input tensor.
18
+ scale (bool, optional, default=True): If ``True``, scale the
19
+ output by a learnable parameter.
20
+ clamp (bool, optional, default=``False``): If ``True``, clamp the
21
+ scaling parameter to (-1, 1).
22
+ weight_norm_type (str, optional, default='none'):
23
+ Type of weight normalization.
24
+ ``'none'``, ``'spectral'``, ``'weight'``
25
+ or ``'weight_demod'``.
26
+ """
27
+
28
+ def __init__(self,
29
+ in_channels,
30
+ scale=True,
31
+ clamp=False,
32
+ weight_norm_type='none'):
33
+ super(NonLocal2dBlock, self).__init__()
34
+ self.clamp = clamp
35
+ self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
36
+ self.in_channels = in_channels
37
+ base_conv2d_block = partial(Conv2dBlock,
38
+ kernel_size=1,
39
+ stride=1,
40
+ padding=0,
41
+ weight_norm_type=weight_norm_type)
42
+ self.theta = base_conv2d_block(in_channels, in_channels // 8)
43
+ self.phi = base_conv2d_block(in_channels, in_channels // 8)
44
+ self.g = base_conv2d_block(in_channels, in_channels // 2)
45
+ self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
46
+ self.softmax = nn.Softmax(dim=-1)
47
+ self.max_pool = nn.MaxPool2d(2)
48
+
49
+ def forward(self, x):
50
+ r"""
51
+
52
+ Args:
53
+ x (tensor): Input feature maps of shape (B, C, H, W).
54
+ Returns:
55
+ out (tensor): Input feature plus the ``gamma``-scaled
56
+ self-attention response. Only this tensor is returned;
57
+ the attention map itself is not.
58
+ """
59
+ n, c, h, w = x.size()
60
+ theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
61
+
62
+ phi = self.phi(x)
63
+ phi = self.max_pool(phi).view(n, -1, h * w // 4)
64
+
65
+ energy = torch.bmm(theta, phi)
66
+ attention = self.softmax(energy)
67
+
68
+ g = self.g(x)
69
+ g = self.max_pool(g).view(n, -1, h * w // 4)
70
+
71
+ out = torch.bmm(g, attention.permute(0, 2, 1))
72
+ out = out.view(n, c // 2, h, w)
73
+ out = self.out_conv(out)
74
+
75
+ if self.clamp:
76
+ out = self.gamma.clamp(-1, 1) * out + x
77
+ else:
78
+ out = self.gamma * out + x
79
+ return out
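# Editor's usage sketch, not part of the uploaded file. The input height and
# width must be even because phi and g are max-pooled by a factor of 2 before
# the attention matmul; the output keeps the input shape.
def _example_non_local():
    """Apply the self-attention block to a random 2D feature map."""
    block = NonLocal2dBlock(64, scale=True, clamp=False)
    x = torch.randn(2, 64, 16, 16)
    out = block(x)  # (2, 64, 16, 16): gamma-scaled attention response + x
    return out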
models/layers/nonlinearity.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from torch import nn
6
+
7
+
8
+ def get_nonlinearity_layer(nonlinearity_type, inplace):
9
+ r"""Return a nonlinearity layer.
10
+
11
+ Args:
12
+ nonlinearity_type (str):
13
+ Type of nonlinear activation function.
14
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
15
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
16
+ inplace (bool): If ``True``, set ``inplace=True`` when initializing
17
+ the nonlinearity layer.
18
+ """
19
+ if nonlinearity_type == 'relu':
20
+ nonlinearity = nn.ReLU(inplace=inplace)
21
+ elif nonlinearity_type == 'leakyrelu':
22
+ nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
23
+ elif nonlinearity_type == 'prelu':
24
+ nonlinearity = nn.PReLU()
25
+ elif nonlinearity_type == 'tanh':
26
+ nonlinearity = nn.Tanh()
27
+ elif nonlinearity_type == 'sigmoid':
28
+ nonlinearity = nn.Sigmoid()
29
+ elif nonlinearity_type.startswith('softmax'):
30
+ dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
31
+ nonlinearity = nn.Softmax(dim=int(dim))
32
+ elif nonlinearity_type == 'none' or nonlinearity_type == '':
33
+ nonlinearity = None
34
+ else:
35
+ raise ValueError('Nonlinearity %s is not recognized' %
36
+ nonlinearity_type)
37
+ return nonlinearity
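# Editor's usage sketch, not part of the uploaded file. Besides the plain
# names listed in the docstring, the softmax option may carry an explicit
# dimension, e.g. 'softmax,2' builds nn.Softmax(dim=2); without the suffix the
# dimension defaults to 1, and 'none' (or '') returns None.
def _example_get_nonlinearity_layer():
    relu = get_nonlinearity_layer('relu', inplace=False)       # nn.ReLU()
    lrelu = get_nonlinearity_layer('leakyrelu', inplace=True)  # slope 0.2
    softmax_dim2 = get_nonlinearity_layer('softmax,2', inplace=False)
    identity = get_nonlinearity_layer('none', inplace=False)   # None
    return relu, lrelu, softmax_dim2, identity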
models/layers/residual.py ADDED
@@ -0,0 +1,1235 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import functools
6
+
7
+ from torch import nn
8
+ from torch.nn import Upsample as NearestUpsample
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
12
+ LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
13
+ PartialConv3dBlock)
14
+
15
+
16
+ class _BaseResBlock(nn.Module):
17
+ r"""An abstract class for residual blocks.
18
+ """
19
+
20
+ def __init__(self, in_channels, out_channels, kernel_size,
21
+ padding, dilation, groups, bias, padding_mode,
22
+ weight_norm_type, weight_norm_params,
23
+ activation_norm_type, activation_norm_params,
24
+ skip_activation_norm, skip_nonlinearity,
25
+ nonlinearity, inplace_nonlinearity, apply_noise,
26
+ hidden_channels_equal_out_channels,
27
+ order, block, learn_shortcut):
28
+ super().__init__()
29
+ if order == 'pre_act':
30
+ order = 'NACNAC'
31
+ if isinstance(bias, bool):
32
+ # The bias for conv_block_0, conv_block_1, and conv_block_s.
33
+ biases = [bias, bias, bias]
34
+ elif isinstance(bias, list):
35
+ if len(bias) == 3:
36
+ biases = bias
37
+ else:
38
+ raise ValueError('The bias list must contain 3 elements.')
39
+ else:
40
+ raise ValueError('Bias must be either a bool or a list.')
41
+ self.learn_shortcut = (in_channels != out_channels) or learn_shortcut
42
+ if len(order) > 6 or len(order) < 5:
43
+ raise ValueError('order must be either 5 or 6 characters')
44
+ if hidden_channels_equal_out_channels:
45
+ hidden_channels = out_channels
46
+ else:
47
+ hidden_channels = min(in_channels, out_channels)
48
+
49
+ # Parameters that are specific for convolutions.
50
+ conv_main_params = {}
51
+ conv_skip_params = {}
52
+ if block != LinearBlock:
53
+ conv_base_params = dict(stride=1, dilation=dilation,
54
+ groups=groups, padding_mode=padding_mode)
55
+ conv_main_params.update(conv_base_params)
56
+ conv_main_params.update(
57
+ dict(kernel_size=kernel_size,
58
+ activation_norm_type=activation_norm_type,
59
+ activation_norm_params=activation_norm_params,
60
+ padding=padding))
61
+ conv_skip_params.update(conv_base_params)
62
+ conv_skip_params.update(dict(kernel_size=1))
63
+ if skip_activation_norm:
64
+ conv_skip_params.update(
65
+ dict(activation_norm_type=activation_norm_type,
66
+ activation_norm_params=activation_norm_params))
67
+
68
+ # Other parameters.
69
+ other_params = dict(weight_norm_type=weight_norm_type,
70
+ weight_norm_params=weight_norm_params,
71
+ apply_noise=apply_noise)
72
+
73
+ # Residual branch.
74
+ if order.find('A') < order.find('C') and \
75
+ (activation_norm_type == '' or activation_norm_type == 'none'):
76
+ # Nonlinearity is the first operation in the residual path.
77
+ # In-place nonlinearity will modify the input variable and cause
78
+ # backward error.
79
+ first_inplace = False
80
+ else:
81
+ first_inplace = inplace_nonlinearity
82
+ self.conv_block_0 = block(in_channels, hidden_channels,
83
+ bias=biases[0],
84
+ nonlinearity=nonlinearity,
85
+ order=order[0:3],
86
+ inplace_nonlinearity=first_inplace,
87
+ **conv_main_params,
88
+ **other_params)
89
+ self.conv_block_1 = block(hidden_channels, out_channels,
90
+ bias=biases[1],
91
+ nonlinearity=nonlinearity,
92
+ order=order[3:],
93
+ inplace_nonlinearity=inplace_nonlinearity,
94
+ **conv_main_params,
95
+ **other_params)
96
+
97
+ # Shortcut branch.
98
+ if self.learn_shortcut:
99
+ if skip_nonlinearity:
100
+ skip_nonlinearity_type = nonlinearity
101
+ else:
102
+ skip_nonlinearity_type = ''
103
+ self.conv_block_s = block(in_channels, out_channels,
104
+ bias=biases[2],
105
+ nonlinearity=skip_nonlinearity_type,
106
+ order=order[0:3],
107
+ **conv_skip_params,
108
+ **other_params)
109
+
110
+ # Whether this block expects conditional inputs.
111
+ self.conditional = \
112
+ getattr(self.conv_block_0, 'conditional', False) or \
113
+ getattr(self.conv_block_1, 'conditional', False)
114
+
115
+ def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs):
116
+ r"""Returns the output of the residual branch.
117
+
118
+ Args:
119
+ x (tensor): Input tensor.
120
+ cond_inputs (list of tensors) : Conditional input tensors.
121
+ kw_cond_inputs (dict) : Keyword conditional inputs.
122
+ Returns:
123
+ dx (tensor): Output tensor.
124
+ """
125
+ dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
126
+ dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
127
+ return dx
128
+
129
+ def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
130
+ r"""
131
+
132
+ Args:
133
+ x (tensor): Input tensor.
134
+ cond_inputs (list of tensors) : Conditional input tensors.
135
+ do_checkpoint (bool, optional, default=``False``) If ``True``,
136
+ trade compute for memory by checkpointing the model.
137
+ kw_cond_inputs (dict) : Keyword conditional inputs.
138
+ Returns:
139
+ output (tensor): Output tensor.
140
+ """
141
+ if do_checkpoint:
142
+ dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
143
+ else:
144
+ dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)
145
+
146
+ if self.learn_shortcut:
147
+ x_shortcut = self.conv_block_s(x, *cond_inputs, **kw_cond_inputs)
148
+ else:
149
+ x_shortcut = x
150
+ output = x_shortcut + dx
151
+ return output
152
+
153
+
154
+ class ResLinearBlock(_BaseResBlock):
155
+ r"""Residual block with full-connected layers.
156
+
157
+ Args:
158
+ in_channels (int) : Number of channels in the input tensor.
159
+ out_channels (int) : Number of channels in the output tensor.
160
+ weight_norm_type (str, optional, default='none'):
161
+ Type of weight normalization.
162
+ ``'none'``, ``'spectral'``, ``'weight'``
163
+ or ``'weight_demod'``.
164
+ weight_norm_params (obj, optional, default=None):
165
+ Parameters of weight normalization.
166
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
167
+ keyword arguments when initializing weight normalization.
168
+ activation_norm_type (str, optional, default='none'):
169
+ Type of activation normalization.
170
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
171
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
172
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
173
+ activation_norm_params (obj, optional, default=None):
174
+ Parameters of activation normalization.
175
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
176
+ keyword arguments when initializing activation normalization.
177
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
178
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
179
+ learned shortcut connection.
180
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
181
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
182
+ learned shortcut connection.
183
+ nonlinearity (str, optional, default='none'):
184
+ Type of nonlinear activation function in the residual link.
185
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
186
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
187
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
188
+ set ``inplace=True`` when initializing the nonlinearity layers.
189
+ apply_noise (bool, optional, default=False): If ``True``, add
190
+ Gaussian noise with learnable magnitude after the
191
+ fully-connected layer.
192
+ hidden_channels_equal_out_channels (bool, optional, default=False):
193
+ If ``True``, set the hidden channel number to be equal to the
194
+ output channel number. If ``False``, the hidden channel number
195
+ equals the smaller of the input channel number and the
196
+ output channel number.
197
+ order (str, optional, default='CNACNA'): Order of operations
198
+ in the residual link.
199
+ ``'C'``: fully-connected,
200
+ ``'N'``: normalization,
201
+ ``'A'``: nonlinear activation.
202
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
203
+ a convolutional shortcut instead of an identity one, otherwise only
204
+ use a convolutional one if input and output have different number of
205
+ channels.
206
+ """
207
+
208
+ def __init__(self, in_channels, out_channels, bias=True,
209
+ weight_norm_type='none', weight_norm_params=None,
210
+ activation_norm_type='none', activation_norm_params=None,
211
+ skip_activation_norm=True, skip_nonlinearity=False,
212
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
213
+ apply_noise=False, hidden_channels_equal_out_channels=False,
214
+ order='CNACNA', learn_shortcut=False):
215
+ super().__init__(in_channels, out_channels, None, None,
216
+ None, None, bias, None,
217
+ weight_norm_type, weight_norm_params,
218
+ activation_norm_type, activation_norm_params,
219
+ skip_activation_norm, skip_nonlinearity,
220
+ nonlinearity, inplace_nonlinearity,
221
+ apply_noise, hidden_channels_equal_out_channels,
222
+ order, LinearBlock, learn_shortcut)
223
+
224
+
225
+ class Res1dBlock(_BaseResBlock):
226
+ r"""Residual block for 1D input.
227
+
228
+ Args:
229
+ in_channels (int) : Number of channels in the input tensor.
230
+ out_channels (int) : Number of channels in the output tensor.
231
+ kernel_size (int, optional, default=3): Kernel size for the
232
+ convolutional filters in the residual link.
233
+ padding (int, optional, default=1): Padding size.
234
+ dilation (int, optional, default=1): Dilation factor.
235
+ groups (int, optional, default=1): Number of convolutional/linear
236
+ groups.
237
+ padding_mode (string, optional, default='zeros'): Type of padding:
238
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
239
+ weight_norm_type (str, optional, default='none'):
240
+ Type of weight normalization.
241
+ ``'none'``, ``'spectral'``, ``'weight'``
242
+ or ``'weight_demod'``.
243
+ weight_norm_params (obj, optional, default=None):
244
+ Parameters of weight normalization.
245
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
246
+ keyword arguments when initializing weight normalization.
247
+ activation_norm_type (str, optional, default='none'):
248
+ Type of activation normalization.
249
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
250
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
251
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
252
+ activation_norm_params (obj, optional, default=None):
253
+ Parameters of activation normalization.
254
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
255
+ keyword arguments when initializing activation normalization.
256
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
257
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
258
+ learned shortcut connection.
259
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
260
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
261
+ learned shortcut connection.
262
+ nonlinearity (str, optional, default='none'):
263
+ Type of nonlinear activation function in the residual link.
264
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
265
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
266
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
267
+ set ``inplace=True`` when initializing the nonlinearity layers.
268
+ apply_noise (bool, optional, default=False): If ``True``, adds
269
+ Gaussian noise with learnable magnitude to the convolution output.
270
+ hidden_channels_equal_out_channels (bool, optional, default=False):
271
+ If ``True``, set the hidden channel number to be equal to the
272
+ output channel number. If ``False``, the hidden channel number
273
+ equals to the smaller of the input channel number and the
274
+ output channel number.
275
+ order (str, optional, default='CNACNA'): Order of operations
276
+ in the residual link.
277
+ ``'C'``: convolution,
278
+ ``'N'``: normalization,
279
+ ``'A'``: nonlinear activation.
280
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
281
+ a convolutional shortcut instead of an identity one, otherwise only
282
+ use a convolutional one if input and output have different number of
283
+ channels.
284
+ """
285
+
286
+ def __init__(self, in_channels, out_channels, kernel_size=3,
287
+ padding=1, dilation=1, groups=1, bias=True,
288
+ padding_mode='zeros',
289
+ weight_norm_type='none', weight_norm_params=None,
290
+ activation_norm_type='none', activation_norm_params=None,
291
+ skip_activation_norm=True, skip_nonlinearity=False,
292
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
293
+ apply_noise=False, hidden_channels_equal_out_channels=False,
294
+ order='CNACNA', learn_shortcut=False):
295
+ super().__init__(in_channels, out_channels, kernel_size, padding,
296
+ dilation, groups, bias, padding_mode,
297
+ weight_norm_type, weight_norm_params,
298
+ activation_norm_type, activation_norm_params,
299
+ skip_activation_norm, skip_nonlinearity,
300
+ nonlinearity, inplace_nonlinearity, apply_noise,
301
+ hidden_channels_equal_out_channels,
302
+ order, Conv1dBlock, learn_shortcut)
303
+
304
+
305
+ class Res2dBlock(_BaseResBlock):
306
+ r"""Residual block for 2D input.
307
+
308
+ Args:
309
+ in_channels (int) : Number of channels in the input tensor.
310
+ out_channels (int) : Number of channels in the output tensor.
311
+ kernel_size (int, optional, default=3): Kernel size for the
312
+ convolutional filters in the residual link.
313
+ padding (int, optional, default=1): Padding size.
314
+ dilation (int, optional, default=1): Dilation factor.
315
+ groups (int, optional, default=1): Number of convolutional/linear
316
+ groups.
317
+ padding_mode (string, optional, default='zeros'): Type of padding:
318
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
319
+ weight_norm_type (str, optional, default='none'):
320
+ Type of weight normalization.
321
+ ``'none'``, ``'spectral'``, ``'weight'``
322
+ or ``'weight_demod'``.
323
+ weight_norm_params (obj, optional, default=None):
324
+ Parameters of weight normalization.
325
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
326
+ keyword arguments when initializing weight normalization.
327
+ activation_norm_type (str, optional, default='none'):
328
+ Type of activation normalization.
329
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
330
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
331
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
332
+ activation_norm_params (obj, optional, default=None):
333
+ Parameters of activation normalization.
334
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
335
+ keyword arguments when initializing activation normalization.
336
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
337
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
338
+ learned shortcut connection.
339
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
340
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
341
+ learned shortcut connection.
342
+ nonlinearity (str, optional, default='none'):
343
+ Type of nonlinear activation function in the residual link.
344
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
345
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
346
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
347
+ set ``inplace=True`` when initializing the nonlinearity layers.
348
+ apply_noise (bool, optional, default=False): If ``True``, adds
349
+ Gaussian noise with learnable magnitude to the convolution output.
350
+ hidden_channels_equal_out_channels (bool, optional, default=False):
351
+ If ``True``, set the hidden channel number to be equal to the
352
+ output channel number. If ``False``, the hidden channel number
353
+ equals to the smaller of the input channel number and the
354
+ output channel number.
355
+ order (str, optional, default='CNACNA'): Order of operations
356
+ in the residual link.
357
+ ``'C'``: convolution,
358
+ ``'N'``: normalization,
359
+ ``'A'``: nonlinear activation.
360
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
361
+ a convolutional shortcut instead of an identity one, otherwise only
362
+ use a convolutional one if input and output have different number of
363
+ channels.
364
+ """
365
+
366
+ def __init__(self, in_channels, out_channels, kernel_size=3,
367
+ padding=1, dilation=1, groups=1, bias=True,
368
+ padding_mode='zeros',
369
+ weight_norm_type='none', weight_norm_params=None,
370
+ activation_norm_type='none', activation_norm_params=None,
371
+ skip_activation_norm=True, skip_nonlinearity=False,
372
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
373
+ apply_noise=False, hidden_channels_equal_out_channels=False,
374
+ order='CNACNA', learn_shortcut=False):
375
+ super().__init__(in_channels, out_channels, kernel_size, padding,
376
+ dilation, groups, bias, padding_mode,
377
+ weight_norm_type, weight_norm_params,
378
+ activation_norm_type, activation_norm_params,
379
+ skip_activation_norm, skip_nonlinearity,
380
+ nonlinearity, inplace_nonlinearity, apply_noise,
381
+ hidden_channels_equal_out_channels,
382
+ order, Conv2dBlock, learn_shortcut)
383
+
384
+
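# Editor's usage sketch, not part of the uploaded file. With different input
# and output channel counts the block builds a learned 1x1 shortcut; with
# matching counts (and learn_shortcut=False) the shortcut is the identity.
def _example_res2d_block():
    """Apply a residual block with instance norm and ReLU to a feature map."""
    import torch  # local import: this module only pulls in torch.nn at the top
    block = Res2dBlock(32, 64,
                       activation_norm_type='instance',
                       nonlinearity='relu')
    x = torch.randn(4, 32, 16, 16)
    out = block(x)  # (4, 64, 16, 16); the shortcut here is a learned 1x1 conv
    # block(x, do_checkpoint=True) trades compute for memory during training.
    return out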
385
+ class Res3dBlock(_BaseResBlock):
386
+ r"""Residual block for 3D input.
387
+
388
+ Args:
389
+ in_channels (int) : Number of channels in the input tensor.
390
+ out_channels (int) : Number of channels in the output tensor.
391
+ kernel_size (int, optional, default=3): Kernel size for the
392
+ convolutional filters in the residual link.
393
+ padding (int, optional, default=1): Padding size.
394
+ dilation (int, optional, default=1): Dilation factor.
395
+ groups (int, optional, default=1): Number of convolutional/linear
396
+ groups.
397
+ padding_mode (string, optional, default='zeros'): Type of padding:
398
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
399
+ weight_norm_type (str, optional, default='none'):
400
+ Type of weight normalization.
401
+ ``'none'``, ``'spectral'``, ``'weight'``
402
+ or ``'weight_demod'``.
403
+ weight_norm_params (obj, optional, default=None):
404
+ Parameters of weight normalization.
405
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
406
+ keyword arguments when initializing weight normalization.
407
+ activation_norm_type (str, optional, default='none'):
408
+ Type of activation normalization.
409
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
410
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
411
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
412
+ activation_norm_params (obj, optional, default=None):
413
+ Parameters of activation normalization.
414
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
415
+ keyword arguments when initializing activation normalization.
416
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
417
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
418
+ learned shortcut connection.
419
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
420
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
421
+ learned shortcut connection.
422
+ nonlinearity (str, optional, default='none'):
423
+ Type of nonlinear activation function in the residual link.
424
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
425
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
426
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
427
+ set ``inplace=True`` when initializing the nonlinearity layers.
428
+ apply_noise (bool, optional, default=False): If ``True``, adds
429
+ Gaussian noise with learnable magnitude to the convolution output.
430
+ hidden_channels_equal_out_channels (bool, optional, default=False):
431
+ If ``True``, set the hidden channel number to be equal to the
432
+ output channel number. If ``False``, the hidden channel number
433
+ equals to the smaller of the input channel number and the
434
+ output channel number.
435
+ order (str, optional, default='CNACNA'): Order of operations
436
+ in the residual link.
437
+ ``'C'``: convolution,
438
+ ``'N'``: normalization,
439
+ ``'A'``: nonlinear activation.
440
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
441
+ a convolutional shortcut instead of an identity one, otherwise only
442
+ use a convolutional one if input and output have different number of
443
+ channels.
444
+ """
445
+
446
+ def __init__(self, in_channels, out_channels, kernel_size=3,
447
+ padding=1, dilation=1, groups=1, bias=True,
448
+ padding_mode='zeros',
449
+ weight_norm_type='none', weight_norm_params=None,
450
+ activation_norm_type='none', activation_norm_params=None,
451
+ skip_activation_norm=True, skip_nonlinearity=False,
452
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
453
+ apply_noise=False, hidden_channels_equal_out_channels=False,
454
+ order='CNACNA', learn_shortcut=False):
455
+ super().__init__(in_channels, out_channels, kernel_size, padding,
456
+ dilation, groups, bias, padding_mode,
457
+ weight_norm_type, weight_norm_params,
458
+ activation_norm_type, activation_norm_params,
459
+ skip_activation_norm, skip_nonlinearity,
460
+ nonlinearity, inplace_nonlinearity, apply_noise,
461
+ hidden_channels_equal_out_channels,
462
+ order, Conv3dBlock, learn_shortcut)
463
+
464
+
465
+ class _BaseHyperResBlock(_BaseResBlock):
466
+ r"""An abstract class for hyper residual blocks.
467
+ """
468
+
469
+ def __init__(self, in_channels, out_channels, kernel_size,
470
+ padding, dilation, groups, bias, padding_mode,
471
+ weight_norm_type, weight_norm_params,
472
+ activation_norm_type, activation_norm_params,
473
+ skip_activation_norm, skip_nonlinearity,
474
+ nonlinearity, inplace_nonlinearity, apply_noise,
475
+ hidden_channels_equal_out_channels,
476
+ order,
477
+ is_hyper_conv, is_hyper_norm, block, learn_shortcut):
478
+ block = functools.partial(block,
479
+ is_hyper_conv=is_hyper_conv,
480
+ is_hyper_norm=is_hyper_norm)
481
+ super().__init__(in_channels, out_channels, kernel_size, padding,
482
+ dilation, groups, bias, padding_mode,
483
+ weight_norm_type, weight_norm_params,
484
+ activation_norm_type, activation_norm_params,
485
+ skip_activation_norm, skip_nonlinearity,
486
+ nonlinearity, inplace_nonlinearity, apply_noise,
487
+ hidden_channels_equal_out_channels,
488
+ order, block, learn_shortcut)
489
+
490
+ def forward(self, x, *cond_inputs, conv_weights=(None,) * 3,
491
+ norm_weights=(None,) * 3, **kw_cond_inputs):
492
+ r"""
493
+
494
+ Args:
495
+ x (tensor): Input tensor.
496
+ cond_inputs (list of tensors) : Conditional input tensors.
497
+ conv_weights (list of tensors): Convolution weights for
498
+ three convolutional layers respectively.
499
+ norm_weights (list of tensors): Normalization weights for
500
+ three convolutional layers respectively.
501
+ kw_cond_inputs (dict) : Keyword conditional inputs.
502
+ Returns:
503
+ output (tensor): Output tensor.
504
+ """
505
+ dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0],
506
+ norm_weights=norm_weights[0])
507
+ dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1],
508
+ norm_weights=norm_weights[1])
509
+ if self.learn_shortcut:
510
+ x_shortcut = self.conv_block_s(x, *cond_inputs,
511
+ conv_weights=conv_weights[2],
512
+ norm_weights=norm_weights[2])
513
+ else:
514
+ x_shortcut = x
515
+ output = x_shortcut + dx
516
+ return output
517
+
518
+
519
+ class HyperRes2dBlock(_BaseHyperResBlock):
520
+ r"""Hyper residual block for 2D input.
521
+
522
+ Args:
523
+ in_channels (int) : Number of channels in the input tensor.
524
+ out_channels (int) : Number of channels in the output tensor.
525
+ kernel_size (int, optional, default=3): Kernel size for the
526
+ convolutional filters in the residual link.
527
+ padding (int, optional, default=1): Padding size.
528
+ dilation (int, optional, default=1): Dilation factor.
529
+ groups (int, optional, default=1): Number of convolutional/linear
530
+ groups.
531
+ padding_mode (string, optional, default='zeros'): Type of padding:
532
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
533
+ weight_norm_type (str, optional, default='none'):
534
+ Type of weight normalization.
535
+ ``'none'``, ``'spectral'``, ``'weight'``
536
+ or ``'weight_demod'``.
537
+ weight_norm_params (obj, optional, default=None):
538
+ Parameters of weight normalization.
539
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
540
+ keyword arguments when initializing weight normalization.
541
+ activation_norm_type (str, optional, default='none'):
542
+ Type of activation normalization.
543
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
544
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
545
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
546
+ activation_norm_params (obj, optional, default=None):
547
+ Parameters of activation normalization.
548
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
549
+ keyword arguments when initializing activation normalization.
550
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
551
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
552
+ learned shortcut connection.
553
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
554
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
555
+ learned shortcut connection.
556
+ nonlinearity (str, optional, default='none'):
557
+ Type of nonlinear activation function in the residual link.
558
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
559
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
560
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
561
+ set ``inplace=True`` when initializing the nonlinearity layers.
562
+ apply_noise (bool, optional, default=False): If ``True``, adds
563
+ Gaussian noise with learnable magnitude to the convolution output.
564
+ hidden_channels_equal_out_channels (bool, optional, default=False):
565
+ If ``True``, set the hidden channel number to be equal to the
566
+ output channel number. If ``False``, the hidden channel number
567
+ equals to the smaller of the input channel number and the
568
+ output channel number.
569
+ order (str, optional, default='CNACNA'): Order of operations
570
+ in the residual link.
571
+ ``'C'``: convolution,
572
+ ``'N'``: normalization,
573
+ ``'A'``: nonlinear activation.
574
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
575
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
576
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
577
+ hyper normalizations.
578
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
579
+ a convolutional shortcut instead of an identity one, otherwise only
580
+ use a convolutional one if input and output have different number of
581
+ channels.
582
+ """
583
+
584
+ def __init__(self, in_channels, out_channels, kernel_size=3,
585
+ padding=1, dilation=1, groups=1, bias=True,
586
+ padding_mode='zeros',
587
+ weight_norm_type='', weight_norm_params=None,
588
+ activation_norm_type='', activation_norm_params=None,
589
+ skip_activation_norm=True, skip_nonlinearity=False,
590
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
591
+ apply_noise=False, hidden_channels_equal_out_channels=False,
592
+ order='CNACNA', is_hyper_conv=False, is_hyper_norm=False,
593
+ learn_shortcut=False):
594
+ super().__init__(in_channels, out_channels, kernel_size, padding,
595
+ dilation, groups, bias, padding_mode,
596
+ weight_norm_type, weight_norm_params,
597
+ activation_norm_type, activation_norm_params,
598
+ skip_activation_norm, skip_nonlinearity,
599
+ nonlinearity, inplace_nonlinearity, apply_noise,
600
+ hidden_channels_equal_out_channels,
601
+ order, is_hyper_conv, is_hyper_norm,
602
+ HyperConv2dBlock, learn_shortcut)
603
+
604
+
605
+ class _BaseDownResBlock(_BaseResBlock):
606
+ r"""An abstract class for residual blocks with downsampling.
607
+ """
608
+
609
+ def __init__(self, in_channels, out_channels, kernel_size,
610
+ padding, dilation, groups, bias, padding_mode,
611
+ weight_norm_type, weight_norm_params,
612
+ activation_norm_type, activation_norm_params,
613
+ skip_activation_norm, skip_nonlinearity,
614
+ nonlinearity, inplace_nonlinearity,
615
+ apply_noise, hidden_channels_equal_out_channels,
616
+ order, block, pooling, down_factor, learn_shortcut):
617
+ super().__init__(in_channels, out_channels, kernel_size, padding,
618
+ dilation, groups, bias, padding_mode,
619
+ weight_norm_type, weight_norm_params,
620
+ activation_norm_type, activation_norm_params,
621
+ skip_activation_norm, skip_nonlinearity,
622
+ nonlinearity, inplace_nonlinearity,
623
+ apply_noise, hidden_channels_equal_out_channels,
624
+ order, block, learn_shortcut)
625
+ self.pooling = pooling(down_factor)
626
+
627
+ def forward(self, x, *cond_inputs):
628
+ r"""
629
+
630
+ Args:
631
+ x (tensor) : Input tensor.
632
+ cond_inputs (list of tensors) : conditional input.
633
+ Returns:
634
+ output (tensor) : Output tensor.
635
+ """
636
+ dx = self.conv_block_0(x, *cond_inputs)
637
+ dx = self.conv_block_1(dx, *cond_inputs)
638
+ dx = self.pooling(dx)
639
+ if self.learn_shortcut:
640
+ x_shortcut = self.conv_block_s(x, *cond_inputs)
641
+ else:
642
+ x_shortcut = x
643
+ x_shortcut = self.pooling(x_shortcut)
644
+ output = x_shortcut + dx
645
+ return output
646
+
647
+
648
+ class DownRes2dBlock(_BaseDownResBlock):
649
+ r"""Residual block for 2D input with downsampling.
650
+
651
+ Args:
652
+ in_channels (int) : Number of channels in the input tensor.
653
+ out_channels (int) : Number of channels in the output tensor.
654
+ kernel_size (int, optional, default=3): Kernel size for the
655
+ convolutional filters in the residual link.
656
+ padding (int, optional, default=1): Padding size.
657
+ dilation (int, optional, default=1): Dilation factor.
658
+ groups (int, optional, default=1): Number of convolutional/linear
659
+ groups.
660
+ padding_mode (string, optional, default='zeros'): Type of padding:
661
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
662
+ weight_norm_type (str, optional, default='none'):
663
+ Type of weight normalization.
664
+ ``'none'``, ``'spectral'``, ``'weight'``
665
+ or ``'weight_demod'``.
666
+ weight_norm_params (obj, optional, default=None):
667
+ Parameters of weight normalization.
668
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
669
+ keyword arguments when initializing weight normalization.
670
+ activation_norm_type (str, optional, default='none'):
671
+ Type of activation normalization.
672
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
673
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
674
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
675
+ activation_norm_params (obj, optional, default=None):
676
+ Parameters of activation normalization.
677
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
678
+ keyword arguments when initializing activation normalization.
679
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
680
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
681
+ learned shortcut connection.
682
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
683
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
684
+ learned shortcut connection.
685
+ nonlinearity (str, optional, default='none'):
686
+ Type of nonlinear activation function in the residual link.
687
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
688
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
689
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
690
+ set ``inplace=True`` when initializing the nonlinearity layers.
691
+ apply_noise (bool, optional, default=False): If ``True``, adds
692
+ Gaussian noise with learnable magnitude to the convolution output.
693
+ hidden_channels_equal_out_channels (bool, optional, default=False):
694
+ If ``True``, set the hidden channel number to be equal to the
695
+ output channel number. If ``False``, the hidden channel number
696
+ equals to the smaller of the input channel number and the
697
+ output channel number.
698
+ order (str, optional, default='CNACNA'): Order of operations
699
+ in the residual link.
700
+ ``'C'``: convolution,
701
+ ``'N'``: normalization,
702
+ ``'A'``: nonlinear activation.
703
+ pooling (class, optional, default=nn.AvgPool2d): PyTorch pooling
704
+ layer to be used.
705
+ down_factor (int, optional, default=2): Downsampling factor.
706
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
707
+ a convolutional shortcut instead of an identity one, otherwise only
708
+ use a convolutional one if input and output have different number of
709
+ channels.
710
+ """
711
+
712
+ def __init__(self, in_channels, out_channels, kernel_size=3,
713
+ padding=1, dilation=1, groups=1, bias=True,
714
+ padding_mode='zeros',
715
+ weight_norm_type='none', weight_norm_params=None,
716
+ activation_norm_type='none', activation_norm_params=None,
717
+ skip_activation_norm=True, skip_nonlinearity=False,
718
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
719
+ apply_noise=False, hidden_channels_equal_out_channels=False,
720
+ order='CNACNA', pooling=nn.AvgPool2d, down_factor=2,
721
+ learn_shortcut=False):
722
+ super().__init__(in_channels, out_channels, kernel_size, padding,
723
+ dilation, groups, bias, padding_mode,
724
+ weight_norm_type, weight_norm_params,
725
+ activation_norm_type, activation_norm_params,
726
+ skip_activation_norm, skip_nonlinearity,
727
+ nonlinearity, inplace_nonlinearity, apply_noise,
728
+ hidden_channels_equal_out_channels,
729
+ order, Conv2dBlock, pooling,
730
+ down_factor, learn_shortcut)
731
+
732
+
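# Editor's usage sketch, not part of the uploaded file. DownRes2dBlock halves
# the spatial size (default down_factor=2 with nn.AvgPool2d), pooling both the
# residual branch and the shortcut before they are summed.
def _example_downres2d_block():
    """Downsample a 2D feature map by 2x inside a residual block."""
    import torch  # local import: this module only pulls in torch.nn at the top
    block = DownRes2dBlock(32, 64, nonlinearity='leakyrelu')
    x = torch.randn(1, 32, 16, 16)
    out = block(x)  # (1, 64, 8, 8)
    return out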
733
+ class _BaseUpResBlock(_BaseResBlock):
734
+ r"""An abstract class for residual blocks with upsampling.
735
+ """
736
+
737
+ def __init__(self, in_channels, out_channels, kernel_size,
738
+ padding, dilation, groups, bias, padding_mode,
739
+ weight_norm_type, weight_norm_params,
740
+ activation_norm_type, activation_norm_params,
741
+ skip_activation_norm, skip_nonlinearity,
742
+ nonlinearity, inplace_nonlinearity,
743
+ apply_noise, hidden_channels_equal_out_channels,
744
+ order, block, upsample, up_factor, learn_shortcut):
745
+ super().__init__(in_channels, out_channels, kernel_size, padding,
746
+ dilation, groups, bias, padding_mode,
747
+ weight_norm_type, weight_norm_params,
748
+ activation_norm_type, activation_norm_params,
749
+ skip_activation_norm, skip_nonlinearity,
750
+ nonlinearity, inplace_nonlinearity,
751
+ apply_noise, hidden_channels_equal_out_channels,
752
+ order, block, learn_shortcut)
753
+ self.order = order
754
+ self.upsample = upsample(scale_factor=up_factor)
755
+
756
+ def forward(self, x, *cond_inputs):
757
+ r"""Implementation of the up residual block forward function.
758
+ If the first residual block uses the order 'NAC', we first apply the
759
+ activation norm and the nonlinearity at the original resolution,
760
+ then upsample the activation map to the higher resolution, and
761
+ finally apply the convolution.
762
+ For any other order, we run the whole first block at the original
763
+ resolution and then upsample.
764
+
765
+ Args:
766
+ x (tensor) : Input tensor.
767
+ cond_inputs (list of tensors) : Conditional input.
768
+ Returns:
769
+ output (tensor) : Output tensor.
770
+ """
771
+ # In this particular upsample residual block operation, we first
772
+ # upsample the skip connection.
773
+ if self.learn_shortcut:
774
+ x_shortcut = self.upsample(x)
775
+ x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
776
+ else:
777
+ x_shortcut = self.upsample(x)
778
+
779
+ if self.order[0:3] == 'NAC':
780
+ for ix, layer in enumerate(self.conv_block_0.layers.values()):
781
+ if getattr(layer, 'conditional', False):
782
+ x = layer(x, *cond_inputs)
783
+ else:
784
+ x = layer(x)
785
+ if ix == 1:
786
+ x = self.upsample(x)
787
+ else:
788
+ x = self.conv_block_0(x, *cond_inputs)
789
+ x = self.upsample(x)
790
+ x = self.conv_block_1(x, *cond_inputs)
791
+
792
+ output = x_shortcut + x
793
+ return output
794
+
795
+
796
+ class UpRes2dBlock(_BaseUpResBlock):
797
+ r"""Residual block for 2D input with downsampling.
798
+
799
+ Args:
800
+ in_channels (int) : Number of channels in the input tensor.
801
+ out_channels (int) : Number of channels in the output tensor.
802
+ kernel_size (int, optional, default=3): Kernel size for the
803
+ convolutional filters in the residual link.
804
+ padding (int, optional, default=1): Padding size.
805
+ dilation (int, optional, default=1): Dilation factor.
806
+ groups (int, optional, default=1): Number of convolutional/linear
807
+ groups.
808
+ padding_mode (string, optional, default='zeros'): Type of padding:
809
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
810
+ weight_norm_type (str, optional, default='none'):
811
+ Type of weight normalization.
812
+ ``'none'``, ``'spectral'``, ``'weight'``
813
+ or ``'weight_demod'``.
814
+ weight_norm_params (obj, optional, default=None):
815
+ Parameters of weight normalization.
816
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
817
+ keyword arguments when initializing weight normalization.
818
+ activation_norm_type (str, optional, default='none'):
819
+ Type of activation normalization.
820
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
821
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
822
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
823
+ activation_norm_params (obj, optional, default=None):
824
+ Parameters of activation normalization.
825
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
826
+ keyword arguments when initializing activation normalization.
827
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
828
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
829
+ learned shortcut connection.
830
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
831
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
832
+ learned shortcut connection.
833
+ nonlinearity (str, optional, default='none'):
834
+ Type of nonlinear activation function in the residual link.
835
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
836
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
837
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
838
+ set ``inplace=True`` when initializing the nonlinearity layers.
839
+ apply_noise (bool, optional, default=False): If ``True``, adds
840
+ Gaussian noise with learnable magnitude to the convolution output.
841
+ hidden_channels_equal_out_channels (bool, optional, default=False):
842
+ If ``True``, set the hidden channel number to be equal to the
843
+ output channel number. If ``False``, the hidden channel number
844
+ equals to the smaller of the input channel number and the
845
+ output channel number.
846
+ order (str, optional, default='CNACNA'): Order of operations
847
+ in the residual link.
848
+ ``'C'``: convolution,
849
+ ``'N'``: normalization,
850
+ ``'A'``: nonlinear activation.
851
+ upsample (class, optional, default=NearestUpsample): PyTorch
852
+ upsampling layer to be used.
853
+ up_factor (int, optional, default=2): Upsampling factor.
854
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
855
+ a convolutional shortcut instead of an identity one, otherwise only
856
+ use a convolutional one if input and output have different number of
857
+ channels.
858
+ """
859
+
860
+ def __init__(self, in_channels, out_channels, kernel_size=3,
861
+ padding=1, dilation=1, groups=1, bias=True,
862
+ padding_mode='zeros',
863
+ weight_norm_type='none', weight_norm_params=None,
864
+ activation_norm_type='none', activation_norm_params=None,
865
+ skip_activation_norm=True, skip_nonlinearity=False,
866
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
867
+ apply_noise=False, hidden_channels_equal_out_channels=False,
868
+ order='CNACNA', upsample=NearestUpsample, up_factor=2,
869
+ learn_shortcut=False):
870
+ super().__init__(in_channels, out_channels, kernel_size, padding,
871
+ dilation, groups, bias, padding_mode,
872
+ weight_norm_type, weight_norm_params,
873
+ activation_norm_type, activation_norm_params,
874
+ skip_activation_norm, skip_nonlinearity,
875
+ nonlinearity, inplace_nonlinearity,
876
+ apply_noise, hidden_channels_equal_out_channels,
877
+ order, Conv2dBlock,
878
+ upsample, up_factor, learn_shortcut)
879
+
880
+
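# Minimal usage sketch of UpRes2dBlock with the constructor defaults shown
# above (NearestUpsample, up_factor=2, nonlinearity='leakyrelu'). The helper
# name `_demo_up_res_2d_block` is a hypothetical illustration only.
def _demo_up_res_2d_block():
    import torch
    block = UpRes2dBlock(32, 64)
    x = torch.randn(4, 32, 16, 16)
    y = block(x)
    # Channels go 32 -> 64 and the spatial size doubles with up_factor=2:
    # expected y.shape == torch.Size([4, 64, 32, 32]).
    return y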
881
+ class _BasePartialResBlock(_BaseResBlock):
882
+ r"""An abstract class for residual blocks with partial convolution.
883
+ """
884
+
885
+ def __init__(self, in_channels, out_channels, kernel_size,
886
+ padding, dilation, groups, bias, padding_mode,
887
+ weight_norm_type, weight_norm_params,
888
+ activation_norm_type, activation_norm_params,
889
+ skip_activation_norm, skip_nonlinearity,
890
+ nonlinearity, inplace_nonlinearity,
891
+ multi_channel, return_mask,
892
+ apply_noise, hidden_channels_equal_out_channels,
893
+ order, block, learn_shortcut):
894
+ block = functools.partial(block,
895
+ multi_channel=multi_channel,
896
+ return_mask=return_mask)
897
+ self.partial_conv = True
898
+ super().__init__(in_channels, out_channels, kernel_size, padding,
899
+ dilation, groups, bias, padding_mode,
900
+ weight_norm_type, weight_norm_params,
901
+ activation_norm_type, activation_norm_params,
902
+ skip_activation_norm, skip_nonlinearity,
903
+ nonlinearity, inplace_nonlinearity,
904
+ apply_noise, hidden_channels_equal_out_channels,
905
+ order, block, learn_shortcut)
906
+
907
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
908
+ r"""
909
+
910
+ Args:
911
+ x (tensor): Input tensor.
912
+ cond_inputs (list of tensors) : Conditional input tensors.
913
+ mask_in (tensor, optional, default=``None``): If not ``None``,
914
+ it masks the valid input region.
915
+ kw_cond_inputs (dict) : Keyword conditional inputs.
916
+ Returns:
917
+ (tuple):
918
+ - output (tensor): Output tensor.
919
+ - mask_out (tensor, optional): Masks the valid output region.
920
+ """
921
+ if self.conv_block_0.layers.conv.return_mask:
922
+ dx, mask_out = self.conv_block_0(x, *cond_inputs,
923
+ mask_in=mask_in, **kw_cond_inputs)
924
+ dx, mask_out = self.conv_block_1(dx, *cond_inputs,
925
+ mask_in=mask_out, **kw_cond_inputs)
926
+ else:
927
+ dx = self.conv_block_0(x, *cond_inputs,
928
+ mask_in=mask_in, **kw_cond_inputs)
929
+ dx = self.conv_block_1(dx, *cond_inputs,
930
+ mask_in=mask_in, **kw_cond_inputs)
931
+ mask_out = None
932
+
933
+ if self.learn_shortcut:
934
+ x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs,
935
+ **kw_cond_inputs)
936
+ if isinstance(x_shortcut, tuple):
937
+ x_shortcut, _ = x_shortcut
938
+ else:
939
+ x_shortcut = x
940
+ output = x_shortcut + dx
941
+
942
+ if mask_out is not None:
943
+ return output, mask_out
944
+ return output
945
+
946
+
947
+ class PartialRes2dBlock(_BasePartialResBlock):
948
+ r"""Residual block for 2D input with partial convolution.
949
+
950
+ Args:
951
+ in_channels (int) : Number of channels in the input tensor.
952
+ out_channels (int) : Number of channels in the output tensor.
953
+ kernel_size (int, optional, default=3): Kernel size for the
954
+ convolutional filters in the residual link.
955
+ padding (int, optional, default=1): Padding size.
956
+ dilation (int, optional, default=1): Dilation factor.
957
+ groups (int, optional, default=1): Number of convolutional/linear
958
+ groups.
959
+ padding_mode (string, optional, default='zeros'): Type of padding:
960
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
961
+ weight_norm_type (str, optional, default='none'):
962
+ Type of weight normalization.
963
+ ``'none'``, ``'spectral'``, ``'weight'``
964
+ or ``'weight_demod'``.
965
+ weight_norm_params (obj, optional, default=None):
966
+ Parameters of weight normalization.
967
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
968
+ keyword arguments when initializing weight normalization.
969
+ activation_norm_type (str, optional, default='none'):
970
+ Type of activation normalization.
971
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
972
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
973
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
974
+ activation_norm_params (obj, optional, default=None):
975
+ Parameters of activation normalization.
976
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
977
+ keyword arguments when initializing activation normalization.
978
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
979
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
980
+ learned shortcut connection.
981
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
982
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
983
+ learned shortcut connection.
984
+ nonlinearity (str, optional, default='leakyrelu'):
985
+ Type of nonlinear activation function in the residual link.
986
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
987
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
988
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
989
+ set ``inplace=True`` when initializing the nonlinearity layers.
990
+ apply_noise (bool, optional, default=False): If ``True``, adds
991
+ Gaussian noise with learnable magnitude to the convolution output.
992
+ hidden_channels_equal_out_channels (bool, optional, default=False):
993
+ If ``True``, set the hidden channel number to be equal to the
994
+ output channel number. If ``False``, the hidden channel number
995
+ equals the smaller of the input channel number and the
996
+ output channel number.
997
+ order (str, optional, default='CNACNA'): Order of operations
998
+ in the residual link.
999
+ ``'C'``: convolution,
1000
+ ``'N'``: normalization,
1001
+ ``'A'``: nonlinear activation.
1002
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1003
+ a convolutional shortcut instead of an identity one, otherwise only
1004
+ use a convolutional one if input and output have different number of
1005
+ channels.
1006
+ """
1007
+
1008
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1009
+ padding=1, dilation=1, groups=1, bias=True,
1010
+ padding_mode='zeros',
1011
+ weight_norm_type='none', weight_norm_params=None,
1012
+ activation_norm_type='none', activation_norm_params=None,
1013
+ skip_activation_norm=True, skip_nonlinearity=False,
1014
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1015
+ multi_channel=False, return_mask=True,
1016
+ apply_noise=False,
1017
+ hidden_channels_equal_out_channels=False,
1018
+ order='CNACNA', learn_shortcut=False):
1019
+ super().__init__(in_channels, out_channels, kernel_size, padding,
1020
+ dilation, groups, bias, padding_mode,
1021
+ weight_norm_type, weight_norm_params,
1022
+ activation_norm_type, activation_norm_params,
1023
+ skip_activation_norm, skip_nonlinearity,
1024
+ nonlinearity, inplace_nonlinearity,
1025
+ multi_channel, return_mask,
1026
+ apply_noise, hidden_channels_equal_out_channels,
1027
+ order, PartialConv2dBlock, learn_shortcut)
1028
+
1029
+
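# Minimal usage sketch of PartialRes2dBlock. With the default return_mask=True
# the block returns both the features and the updated validity mask. The
# single-channel mask shape follows the usual multi_channel=False convention
# of partial convolutions (an assumption about PartialConv2dBlock's interface),
# and `_demo_partial_res_2d_block` is a hypothetical helper name.
def _demo_partial_res_2d_block():
    import torch
    block = PartialRes2dBlock(16, 16)
    x = torch.randn(2, 16, 32, 32)
    mask = torch.ones(2, 1, 32, 32)   # 1 = valid pixel, 0 = hole
    out, mask_out = block(x, mask_in=mask)
    return out, mask_out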
1030
+ class PartialRes3dBlock(_BasePartialResBlock):
1031
+ r"""Residual block for 3D input with partial convolution.
1032
+
1033
+ Args:
1034
+ in_channels (int) : Number of channels in the input tensor.
1035
+ out_channels (int) : Number of channels in the output tensor.
1036
+ kernel_size (int, optional, default=3): Kernel size for the
1037
+ convolutional filters in the residual link.
1038
+ padding (int, optional, default=1): Padding size.
1039
+ dilation (int, optional, default=1): Dilation factor.
1040
+ groups (int, optional, default=1): Number of convolutional/linear
1041
+ groups.
1042
+ padding_mode (string, optional, default='zeros'): Type of padding:
1043
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1044
+ weight_norm_type (str, optional, default='none'):
1045
+ Type of weight normalization.
1046
+ ``'none'``, ``'spectral'``, ``'weight'``
1047
+ or ``'weight_demod'``.
1048
+ weight_norm_params (obj, optional, default=None):
1049
+ Parameters of weight normalization.
1050
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1051
+ keyword arguments when initializing weight normalization.
1052
+ activation_norm_type (str, optional, default='none'):
1053
+ Type of activation normalization.
1054
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1055
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1056
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1057
+ activation_norm_params (obj, optional, default=None):
1058
+ Parameters of activation normalization.
1059
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1060
+ keyword arguments when initializing activation normalization.
1061
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
1062
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
1063
+ learned shortcut connection.
1064
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
1065
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1066
+ learned shortcut connection.
1067
+ nonlinearity (str, optional, default='leakyrelu'):
1068
+ Type of nonlinear activation function in the residual link.
1069
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1070
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1071
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1072
+ set ``inplace=True`` when initializing the nonlinearity layers.
1073
+ apply_noise (bool, optional, default=False): If ``True``, adds
1074
+ Gaussian noise with learnable magnitude to the convolution output.
1075
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1076
+ If ``True``, set the hidden channel number to be equal to the
1077
+ output channel number. If ``False``, the hidden channel number
1078
+ equals the smaller of the input channel number and the
1079
+ output channel number.
1080
+ order (str, optional, default='CNACNA'): Order of operations
1081
+ in the residual link.
1082
+ ``'C'``: convolution,
1083
+ ``'N'``: normalization,
1084
+ ``'A'``: nonlinear activation.
1085
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1086
+ a convolutional shortcut instead of an identity one, otherwise only
1087
+ use a convolutional one if input and output have different number of
1088
+ channels.
1089
+ """
1090
+
1091
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1092
+ padding=1, dilation=1, groups=1, bias=True,
1093
+ padding_mode='zeros',
1094
+ weight_norm_type='none', weight_norm_params=None,
1095
+ activation_norm_type='none', activation_norm_params=None,
1096
+ skip_activation_norm=True, skip_nonlinearity=False,
1097
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1098
+ multi_channel=False, return_mask=True,
1099
+ apply_noise=False, hidden_channels_equal_out_channels=False,
1100
+ order='CNACNA', learn_shortcut=False):
1101
+ super().__init__(in_channels, out_channels, kernel_size, padding,
1102
+ dilation, groups, bias, padding_mode,
1103
+ weight_norm_type, weight_norm_params,
1104
+ activation_norm_type, activation_norm_params,
1105
+ skip_activation_norm, skip_nonlinearity,
1106
+ nonlinearity, inplace_nonlinearity,
1107
+ multi_channel, return_mask,
1108
+ apply_noise, hidden_channels_equal_out_channels,
1109
+ order, PartialConv3dBlock, learn_shortcut)
1110
+
1111
+
1112
+ class _BaseMultiOutResBlock(_BaseResBlock):
1113
+ r"""An abstract class for residual blocks that can return multiple outputs.
1114
+ """
1115
+
1116
+ def __init__(self, in_channels, out_channels, kernel_size,
1117
+ padding, dilation, groups, bias, padding_mode,
1118
+ weight_norm_type, weight_norm_params,
1119
+ activation_norm_type, activation_norm_params,
1120
+ skip_activation_norm, skip_nonlinearity,
1121
+ nonlinearity, inplace_nonlinearity,
1122
+ apply_noise, hidden_channels_equal_out_channels,
1123
+ order, block, learn_shortcut):
1124
+ self.multiple_outputs = True
1125
+ super().__init__(in_channels, out_channels, kernel_size, padding,
1126
+ dilation, groups, bias, padding_mode,
1127
+ weight_norm_type, weight_norm_params,
1128
+ activation_norm_type, activation_norm_params,
1129
+ skip_activation_norm, skip_nonlinearity,
1130
+ nonlinearity, inplace_nonlinearity, apply_noise,
1131
+ hidden_channels_equal_out_channels,
1132
+ order, block, learn_shortcut)
1133
+
1134
+ def forward(self, x, *cond_inputs):
1135
+ r"""
1136
+
1137
+ Args:
1138
+ x (tensor): Input tensor.
1139
+ cond_inputs (list of tensors) : Conditional input tensors.
1140
+ Returns:
1141
+ (tuple):
1142
+ - output (tensor): Output tensor.
1143
+ - aux_outputs_0 (tensor): Auxiliary output of the first block.
1144
+ - aux_outputs_1 (tensor): Auxiliary output of the second block.
1145
+ """
1146
+ dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs)
1147
+ dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs)
1148
+ if self.learn_shortcut:
1149
+ # We are not using the auxiliary outputs of self.conv_block_s.
1150
+ x_shortcut, _ = self.conv_block_s(x, *cond_inputs)
1151
+ else:
1152
+ x_shortcut = x
1153
+ output = x_shortcut + dx
1154
+ return output, aux_outputs_0, aux_outputs_1
1155
+
1156
+
1157
+ class MultiOutRes2dBlock(_BaseMultiOutResBlock):
1158
+ r"""Residual block for 2D input. It can return multiple outputs, if some
1159
+ layers in the block return more than one output.
1160
+
1161
+ Args:
1162
+ in_channels (int) : Number of channels in the input tensor.
1163
+ out_channels (int) : Number of channels in the output tensor.
1164
+ kernel_size (int, optional, default=3): Kernel size for the
1165
+ convolutional filters in the residual link.
1166
+ padding (int, optional, default=1): Padding size.
1167
+ dilation (int, optional, default=1): Dilation factor.
1168
+ groups (int, optional, default=1): Number of convolutional/linear
1169
+ groups.
1170
+ padding_mode (string, optional, default='zeros'): Type of padding:
1171
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1172
+ weight_norm_type (str, optional, default='none'):
1173
+ Type of weight normalization.
1174
+ ``'none'``, ``'spectral'``, ``'weight'``
1175
+ or ``'weight_demod'``.
1176
+ weight_norm_params (obj, optional, default=None):
1177
+ Parameters of weight normalization.
1178
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1179
+ keyword arguments when initializing weight normalization.
1180
+ activation_norm_type (str, optional, default='none'):
1181
+ Type of activation normalization.
1182
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1183
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1184
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1185
+ activation_norm_params (obj, optional, default=None):
1186
+ Parameters of activation normalization.
1187
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1188
+ keyword arguments when initializing activation normalization.
1189
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
1190
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
1191
+ learned shortcut connection.
1192
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
1193
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1194
+ learned shortcut connection.
1195
+ nonlinearity (str, optional, default='leakyrelu'):
1196
+ Type of nonlinear activation function in the residual link.
1197
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1198
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1199
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1200
+ set ``inplace=True`` when initializing the nonlinearity layers.
1201
+ apply_noise (bool, optional, default=False): If ``True``, adds
1202
+ Gaussian noise with learnable magnitude to the convolution output.
1203
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1204
+ If ``True``, set the hidden channel number to be equal to the
1205
+ output channel number. If ``False``, the hidden channel number
1206
+ equals the smaller of the input channel number and the
1207
+ output channel number.
1208
+ order (str, optional, default='CNACNA'): Order of operations
1209
+ in the residual link.
1210
+ ``'C'``: convolution,
1211
+ ``'N'``: normalization,
1212
+ ``'A'``: nonlinear activation.
1213
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1214
+ a convolutional shortcut instead of an identity one, otherwise only
1215
+ use a convolutional one if input and output have different number of
1216
+ channels.
1217
+ """
1218
+
1219
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1220
+ padding=1, dilation=1, groups=1, bias=True,
1221
+ padding_mode='zeros',
1222
+ weight_norm_type='none', weight_norm_params=None,
1223
+ activation_norm_type='none', activation_norm_params=None,
1224
+ skip_activation_norm=True, skip_nonlinearity=False,
1225
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1226
+ apply_noise=False, hidden_channels_equal_out_channels=False,
1227
+ order='CNACNA', learn_shortcut=False):
1228
+ super().__init__(in_channels, out_channels, kernel_size, padding,
1229
+ dilation, groups, bias, padding_mode,
1230
+ weight_norm_type, weight_norm_params,
1231
+ activation_norm_type, activation_norm_params,
1232
+ skip_activation_norm, skip_nonlinearity,
1233
+ nonlinearity, inplace_nonlinearity,
1234
+ apply_noise, hidden_channels_equal_out_channels,
1235
+ order, MultiOutConv2dBlock, learn_shortcut)
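# Minimal usage sketch of MultiOutRes2dBlock, following the forward() contract
# above: the block returns the residual output plus the auxiliary outputs of
# its two MultiOutConv2dBlock layers (whose exact content depends on the
# layers configured inside them). `_demo_multi_out_res_2d_block` is a
# hypothetical helper name.
def _demo_multi_out_res_2d_block():
    import torch
    block = MultiOutRes2dBlock(32, 32)
    x = torch.randn(2, 32, 64, 64)
    out, aux_0, aux_1 = block(x)
    return out, aux_0, aux_1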
models/layers/weight_norm.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import functools
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn.utils import spectral_norm, weight_norm
10
+
11
+ from .conv import LinearBlock
12
+
13
+
14
+ class WeightDemodulation(nn.Module):
15
+ r"""Weight demodulation in
16
+ "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.
17
+
18
+ Args:
19
+ conv (torch.nn.Modules): Convolutional layer.
20
+ cond_dims (int): The number of channels in the conditional input.
21
+ eps (float, optional, default=1e-8): a value added to the
22
+ denominator for numerical stability.
23
+ adaptive_bias (bool, optional, default=False): If ``True``, adaptively
24
+ predicts bias from the conditional input.
25
+ demod (bool, optional, default=True): If ``True``, performs
26
+ weight demodulation.
27
+ """
28
+
29
+ def __init__(self, conv, cond_dims, eps=1e-8,
30
+ adaptive_bias=False, demod=True):
31
+ super().__init__()
32
+ self.conv = conv
33
+ self.adaptive_bias = adaptive_bias
34
+ if adaptive_bias:
35
+ self.conv.register_parameter('bias', None)
36
+ self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
37
+ self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
38
+ self.eps = eps
39
+ self.demod = demod
40
+ self.conditional = True
41
+
42
+ def forward(self, x, y):
43
+ r"""Weight demodulation forward"""
44
+ b, c, h, w = x.size()
45
+ self.conv.groups = b
46
+ gamma = self.fc_gamma(y)
47
+ gamma = gamma[:, None, :, None, None]
48
+ weight = self.conv.weight[None, :, :, :, :] * (gamma + 1)
49
+
50
+ if self.demod:
51
+ d = torch.rsqrt(
52
+ (weight ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
53
+ weight = weight * d
54
+
55
+ x = x.reshape(1, -1, h, w)
56
+ _, _, *ws = weight.shape
57
+ weight = weight.reshape(b * self.conv.out_channels, *ws)
58
+ x = self.conv.conv2d_forward(x, weight)
59
+
60
+ x = x.reshape(-1, self.conv.out_channels, h, w)
61
+ if self.adaptive_bias:
62
+ x += self.fc_beta(y)[:, :, None, None]
63
+ return x
64
+
65
+
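# Standalone sketch of the demodulation arithmetic used in forward() above,
# written with plain tensors so the scaling is easy to follow: per-sample
# gammas modulate a shared weight, then each sample's weight is rescaled so
# its per-output-channel L2 norm is approximately one. The helper name and
# the example sizes are illustrative only.
def _demo_demodulation_math(eps=1e-8):
    b, c_in, c_out, k = 4, 8, 16, 3
    weight = torch.randn(c_out, c_in, k, k)                  # shared conv weight
    gamma = torch.randn(b, c_in)                             # would come from fc_gamma(y)
    w = weight[None] * (gamma[:, None, :, None, None] + 1)   # (b, c_out, c_in, k, k)
    d = torch.rsqrt((w ** 2).sum(dim=(2, 3, 4), keepdim=True) + eps)
    return w * d                                             # demodulated per-sample weights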
66
+ def weight_demod(conv, cond_dims=256, eps=1e-8, demod=True):
67
+ r"""Weight demodulation."""
68
+ return WeightDemodulation(conv, cond_dims, eps, demod=demod)
69
+
70
+
71
+ def get_weight_norm_layer(norm_type, **norm_params):
72
+ r"""Return weight normalization.
73
+
74
+ Args:
75
+ norm_type (str):
76
+ Type of weight normalization.
77
+ ``'none'``, ``'spectral'``, ``'weight'``
78
+ or ``'weight_demod'``.
79
+ norm_params: Arbitrary keyword arguments that will be used to
80
+ initialize the weight normalization.
81
+ """
82
+ if norm_type == 'none' or norm_type == '': # no normalization
83
+ return lambda x: x
84
+ elif norm_type == 'spectral': # spectral normalization
85
+ return functools.partial(spectral_norm, **norm_params)
86
+ elif norm_type == 'weight': # weight normalization
87
+ return functools.partial(weight_norm, **norm_params)
88
+ elif norm_type == 'weight_demod': # weight demodulation
89
+ return functools.partial(weight_demod, **norm_params)
90
+ else:
91
+ raise ValueError(
92
+ 'Weight norm layer %s is not recognized' % norm_type)
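# Usage sketch for get_weight_norm_layer: the returned factory wraps a layer
# with the requested weight normalization, e.g. spectral normalization via
# torch.nn.utils.spectral_norm. `_demo_get_weight_norm_layer` is a
# hypothetical helper name.
def _demo_get_weight_norm_layer():
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    conv = get_weight_norm_layer('spectral')(conv)
    x = torch.randn(2, 3, 32, 32)
    return conv(x)   # (2, 16, 32, 32), computed with a spectrally normalized weight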