""" The code is based on https://github.com/apple/ml-gsn/ with adaptation. """

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.torch_utils.ops.native_ops import (
    FusedLeakyReLU,
    fused_leaky_relu,
    upfirdn2d,
)
|
|
class DiscriminatorHead(nn.Module):
    def __init__(self, in_channel, disc_stddev=False):
        super().__init__()

        self.disc_stddev = disc_stddev
        stddev_dim = 1 if disc_stddev else 0

        self.conv_stddev = ConvLayer2d(
            in_channel=in_channel + stddev_dim,
            out_channel=in_channel,
            kernel_size=3,
            activate=True,
        )

        self.final_linear = nn.Sequential(
            nn.Flatten(),
            EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
            EqualLinear(in_channel=in_channel, out_channel=1),
        )

    def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
        # Minibatch standard deviation (StyleGAN2): append a feature map of per-group feature
        # statistics so the discriminator can detect a lack of sample diversity. The batch is
        # shuffled before grouping and restored to its original order afterwards.
        perm = torch.randperm(len(x))
        inv_perm = torch.argsort(perm)

        batch, channel, height, width = x.shape
        x = x[perm]

        group = min(batch, stddev_group)
        stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        stddev = stddev[inv_perm]
        x = x[inv_perm]

        out = torch.cat([x, stddev], 1)
        return out

    def forward(self, x):
        if self.disc_stddev:
            x = self.cat_stddev(x)
        x = self.conv_stddev(x)
        out = self.final_linear(x)
        return out
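

# Illustrative usage sketch (not part of the original module): with disc_stddev=True the head
# appends one minibatch-stddev channel before the 3x3 conv, so a (B, C, 4, 4) feature map becomes
# (B, C + 1, 4, 4) internally and the head returns one logit per sample. The sizes below are
# assumptions for the example and require a batch size divisible by stddev_group (default 4).
def _discriminator_head_example():
    head = DiscriminatorHead(in_channel=512, disc_stddev=True)
    feats = torch.randn(8, 512, 4, 4)
    return head(feats).shape  # expected: torch.Size([8, 1])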
|
|
class ConvDecoder(nn.Module):
    def __init__(self, in_channel, out_channel, in_res, out_res):
        super().__init__()

        log_size_in = int(math.log(in_res, 2))
        log_size_out = int(math.log(out_res, 2))

        self.layers = []
        in_ch = in_channel
        for i in range(log_size_in, log_size_out):
            out_ch = in_ch // 2
            self.layers.append(
                ConvLayer2d(
                    in_channel=in_ch,
                    out_channel=out_ch,
                    kernel_size=3,
                    upsample=True,
                    bias=True,
                    activate=True,
                )
            )
            in_ch = out_ch

        self.layers.append(
            ConvLayer2d(
                in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
            )
        )
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layers(x)
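

# Illustrative usage sketch (not part of the original module): each upsampling ConvLayer2d halves
# the channel count and doubles the resolution, and a final 3x3 conv maps to out_channel. The
# sizes below are assumptions for the example only.
def _conv_decoder_example():
    decoder = ConvDecoder(in_channel=256, out_channel=3, in_res=8, out_res=64)
    feats = torch.randn(2, 256, 8, 8)
    return decoder(feats).shape  # expected: torch.Size([2, 3, 64, 64])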
|
|
class StyleDiscriminator(nn.Module):
    def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
        super().__init__()

        log_size_in = int(math.log(in_res, 2))
        log_size_out = int(math.log(4, 2))

        self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)

        self.layers = []
        in_channels = ch_mul
        for i in range(log_size_in, log_size_out, -1):
            out_channels = int(min(in_channels * 2, ch_max))
            self.layers.append(
                ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
            )
            in_channels = out_channels
        self.layers = nn.Sequential(*self.layers)

        self.disc_out = DiscriminatorHead(in_channel=in_channels, disc_stddev=True)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.layers(x)
        out = self.disc_out(x)
        return out
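

# Illustrative usage sketch (not part of the original module): the discriminator downsamples from
# in_res to 4x4 with residual blocks and returns one logit per image. The 3-channel 64x64 input
# below is an assumption for the example only.
def _style_discriminator_example():
    disc = StyleDiscriminator(in_channel=3, in_res=64)
    imgs = torch.randn(4, 3, 64, 64)
    return disc(imgs).shape  # expected: torch.Size([4, 1])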
|
|
def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k
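

# Worked example (illustrative): a 1D kernel is expanded to its separable 2D outer product and
# normalized to sum to one, e.g.
#   make_kernel([1, 2, 1]) -> 1/16 * [[1, 2, 1],
#                                     [2, 4, 2],
#                                     [1, 2, 1]]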
|
|
class Blur(nn.Module):
    """Blur layer.

    Applies a blur kernel to the input image using a finite impulse response filter. Blurring feature maps after
    convolutional upsampling or before convolutional downsampling helps produce models that are more robust to
    shifting inputs (https://richzhang.github.io/antialiased-cnns/). In the context of GANs, this can provide
    cleaner gradients, and therefore more stable training.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    pad: tuple, int
        A tuple of integers representing the number of rows/columns of padding to be added to the top/left and
        the bottom/right respectively.
    upsample_factor: int
        Upsample factor.

    """
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor**2)

        self.register_buffer("kernel", kernel)
        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)
        return out
|
|
class Upsample(nn.Module):
    """Upsampling layer.

    Perform upsampling using a blur kernel.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    factor: int
        Upsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor**2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
        return out
|
|
class Downsample(nn.Module):
    """Downsampling layer.

    Perform downsampling using a blur kernel.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    factor: int
        Downsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
        return out
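

# Illustrative usage sketch (not part of the original module): with the default factor of 2,
# Upsample doubles and Downsample halves the spatial resolution while filtering with the
# [1, 3, 3, 1] blur kernel. The 16-channel 32x32 input is an assumption for the example only.
def _resample_example():
    x = torch.randn(1, 16, 32, 32)
    up, down = Upsample(), Downsample()
    return up(x).shape, down(x).shape  # expected: (1, 16, 64, 64) and (1, 16, 16, 16)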
|
|
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    bias: bool
        Use bias term.
    bias_init: float
        Initial value for the bias.
    lr_mul: float
        Learning rate multiplier. By scaling weights and the bias we can proportionally scale the magnitude of
        the gradients, effectively increasing/decreasing the learning rate for this layer.
    activate: bool
        Apply leakyReLU activation.

    """
    def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
        else:
            self.bias = None

        self.activate = activate
        self.scale = (1 / math.sqrt(in_channel)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activate:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
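

# Illustrative sketch (not part of the original module): weights are stored divided by lr_mul and
# rescaled by (1 / sqrt(in_channel)) * lr_mul at run time, so the effective weight standard
# deviation stays near 1 / sqrt(in_channel) regardless of lr_mul while gradients scale with
# lr_mul. The sizes and lr_mul value below are assumptions for the example only.
def _equal_linear_example():
    layer = EqualLinear(in_channel=512, out_channel=256, lr_mul=0.01, activate=True)
    effective_std = (layer.weight * layer.scale).std()  # roughly 1 / sqrt(512) ~= 0.044
    z = torch.randn(4, 512)
    return layer(z).shape, effective_std  # expected shape: torch.Size([4, 256])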
|
|
class EqualConv2d(nn.Module):
    """2D convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    bias: bool
        Use bias term.

    """
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )
|
|
class EqualConvTranspose2d(nn.Module):
    """2D transpose convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    output_padding: int
        Extra padding added to input to achieve the desired output size.
    bias: bool
        Use bias term.

    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv_transpose2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )
|
|
class ConvLayer2d(nn.Sequential):
    """Convolutional layer with optional blurred up/downsampling, equalized learning rate, and activation."""
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        layers = []

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            # Transposed convolution followed by a blur to suppress upsampling artifacts.
            layers.append(
                EqualConvTranspose2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=bias and not activate,
                )
            )
            layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            # Blur before the strided convolution for anti-aliased downsampling.
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=bias and not activate,
                )
            )

        if (not downsample) and (not upsample):
            padding = kernel_size // 2

            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=padding,
                    stride=1,
                    bias=bias and not activate,
                )
            )

        if activate:
            # FusedLeakyReLU applies the bias together with the activation, so the conv above omits it.
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)
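

# Illustrative usage sketch (not part of the original module): with the default blur kernel,
# upsample=True doubles and downsample=True halves the spatial resolution. The 16-channel
# 32x32 input below is an assumption for the example only.
def _conv_layer2d_example():
    x = torch.randn(2, 16, 32, 32)
    up = ConvLayer2d(16, 8, kernel_size=3, upsample=True)
    down = ConvLayer2d(16, 32, kernel_size=3, downsample=True)
    return up(x).shape, down(x).shape  # expected: (2, 8, 64, 64) and (2, 32, 16, 16)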
|
|
class ConvResBlock2d(nn.Module):
    """2D convolutional residual block with equalized learning rate.

    Residual block composed of 3x3 convolutions and leaky ReLUs.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    upsample: bool
        Apply upsampling via strided convolution in the first conv.
    downsample: bool
        Apply downsampling via strided convolution in the second conv.

    """
    def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
        super().__init__()

        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        mid_ch = in_channel if downsample else out_channel

        self.conv1 = ConvLayer2d(in_channel, mid_ch, upsample=upsample, kernel_size=3)
        self.conv2 = ConvLayer2d(mid_ch, out_channel, downsample=downsample, kernel_size=3)

        if (in_channel != out_channel) or upsample or downsample:
            self.skip = ConvLayer2d(
                in_channel,
                out_channel,
                upsample=upsample,
                downsample=downsample,
                kernel_size=1,
                activate=False,
                bias=False,
            )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        if hasattr(self, 'skip'):
            skip = self.skip(input)
            out = (out + skip) / math.sqrt(2)
        else:
            out = (out + input) / math.sqrt(2)
        return out
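

# Minimal smoke-test sketch (not part of the original module); assumes the native FusedLeakyReLU
# and upfirdn2d ops imported above are available on the current device.
if __name__ == "__main__":
    block = ConvResBlock2d(in_channel=32, out_channel=64, downsample=True)
    print(block(torch.randn(2, 32, 16, 16)).shape)  # expected: torch.Size([2, 64, 8, 8])
    print(_style_discriminator_example())  # expected: torch.Size([4, 1])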
|
|