import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VanillaConv2d(nn.Module):
    """Conv2d followed by optional activation and normalization.

    padding=-1 computes "same" padding from kernel_size and dilation.
    norm may be "BN" (BatchNorm2d), "IN" (InstanceNorm2d), "SN" (spectral
    normalization on the conv weights), or anything else for no normalization.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)
    ):
        super().__init__()
        if padding == -1:
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size)
            if isinstance(dilation, int):
                dilation = (dilation, dilation)
            # "Same" padding: ((k - 1) * d) // 2 per spatial dimension.
            self.padding = tuple(((np.array(kernel_size) - 1) * np.array(dilation)) // 2)
        else:
            self.padding = padding

        self.featureConv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, self.padding, dilation, groups, bias)

        self.norm = norm
        if norm == "BN":
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == "IN":
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
        elif norm == "SN":
            # Spectral norm wraps the conv itself; no separate norm layer is applied.
            self.norm = None
            self.featureConv = nn.utils.spectral_norm(self.featureConv)
        else:
            self.norm = None

        self.activation = activation

    def forward(self, xs):
        out = self.featureConv(xs)
        if self.activation:
            out = self.activation(out)
        if self.norm is not None:
            out = self.norm_layer(out)
        return out


class VanillaDeconv2d(nn.Module):
    """Upsampling by scale_factor followed by a VanillaConv2d."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
        scale_factor=2
    ):
        super().__init__()
        self.conv = VanillaConv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)
        self.scale_factor = scale_factor

    def forward(self, xs):
        xs_resized = F.interpolate(xs, scale_factor=self.scale_factor)
        return self.conv(xs_resized)
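
# Shape sketch (illustrative, assuming padding=-1 for "same" padding, stride=1,
# an odd kernel, and the default scale_factor=2):
#   VanillaConv2d(3, 64, 5, padding=-1):    (N, 3, H, W)  -> (N, 64, H, W)
#   VanillaDeconv2d(64, 32, 3, padding=-1): (N, 64, H, W) -> (N, 32, 2H, 2W)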


class GatedConv2d(VanillaConv2d):
    """Gated convolution: the feature map is modulated by a learned soft gate."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation
        )
        # Parallel branch that predicts the per-pixel, per-channel gate.
        self.gatingConv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, self.padding, dilation, groups, bias)
        if norm == 'SN':
            self.gatingConv = nn.utils.spectral_norm(self.gatingConv)
        self.sigmoid = nn.Sigmoid()
        self.store_gated_values = False

    def gated(self, mask):
        # return torch.clamp(mask, -1, 1)
        out = self.sigmoid(mask)
        if self.store_gated_values:
            # Keep a detached CPU copy of the gate values for inspection.
            self.gated_values = out.detach().cpu()
        return out

    def forward(self, xs):
        gating = self.gatingConv(xs)
        feature = self.featureConv(xs)
        if self.activation:
            feature = self.activation(feature)
        out = self.gated(gating) * feature
        if self.norm is not None:
            out = self.norm_layer(out)
        return out
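
# Gating sketch (a reading of the code above, not extra functionality): the
# forward pass computes sigmoid(gatingConv(x)) * activation(featureConv(x)),
# so every output activation is scaled by a learned gate in (0, 1). Setting
# store_gated_values = True keeps a detached CPU copy of those gates in
# self.gated_values.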


class GatedDeconv2d(VanillaDeconv2d):
    """Upsampling followed by a GatedConv2d instead of a VanillaConv2d."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
        scale_factor=2
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation, scale_factor
        )
        # Replace the vanilla conv built by the parent with a gated one.
        self.conv = GatedConv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)


class PartialConv2d(VanillaConv2d):
    """Partial convolution: only valid (unmasked) pixels contribute to the output.

    See http://masc.cs.gmu.edu/wiki/partialconv
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation
        )
        # Fixed all-ones conv that counts the valid mask pixels in each window.
        self.mask_sum_conv = nn.Conv2d(1, 1, kernel_size,
                                       stride, self.padding, dilation, groups, False)
        nn.init.constant_(self.mask_sum_conv.weight, 1.0)

        # The mask-sum conv is fixed and must not be updated during training.
        for param in self.mask_sum_conv.parameters():
            param.requires_grad = False

    def forward(self, input_tuple):
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # output = W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0), if sum(M) != 0
        #        = 0,                                                              if sum(M) == 0
        inp, mask = input_tuple

        # C(M .* X)
        output = self.featureConv(mask * inp)

        # C(0) = b
        if self.featureConv.bias is not None:
            output_bias = self.featureConv.bias.view(1, -1, 1, 1)
        else:
            output_bias = torch.zeros([1, 1, 1, 1]).to(inp.device)

        # D(M) = sum(M)
        with torch.no_grad():
            mask_sum = self.mask_sum_conv(mask)

        # Find the windows where sum(M) == 0
        no_update_holes = (mask_sum == 0)

        # Just to prevent division by zero
        mask_sum_no_zero = mask_sum.masked_fill_(no_update_holes, 1.0)

        # output = [C(M .* X) - C(0)] / D(M) + C(0), if sum(M) != 0
        #        = 0,                                if sum(M) == 0
        output = (output - output_bias) / mask_sum_no_zero + output_bias
        output = output.masked_fill_(no_update_holes, 0.0)

        # Create a new mask with only 1 or 0
        new_mask = torch.ones_like(mask_sum)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)

        if self.activation is not None:
            output = self.activation(output)
        if self.norm is not None:
            output = self.norm_layer(output)
        return output, new_mask
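
# Worked example (illustrative numbers, assuming a binary 0/1 mask and a 3x3
# kernel): mask_sum_conv counts the valid pixels in each window. If a window
# covers 4 valid pixels, mask_sum = 4, so the bias-free response is rescaled
# by 1/4 before the bias is added back, and new_mask becomes 1 there. If a
# window covers no valid pixels, the output is zeroed and new_mask stays 0, so
# subsequent partial convolutions keep treating that location as a hole.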


class PartialDeconv2d(nn.Module):
    """Upsampling for (feature, mask) tuples followed by a PartialConv2d."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
                 scale_factor=2):
        super().__init__()
        self.conv = PartialConv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input_tuple):
        inp, mask = input_tuple
        inp_resized = F.interpolate(inp, scale_factor=self.scale_factor)
        with torch.no_grad():
            mask_resized = F.interpolate(mask, scale_factor=self.scale_factor)
        return self.conv((inp_resized, mask_resized))
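

if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch, not part of the original
    # module. It assumes padding=-1 ("same" padding) and scale_factor=2.
    x = torch.randn(1, 3, 64, 64)
    mask = (torch.rand(1, 1, 64, 64) > 0.5).float()

    print(VanillaConv2d(3, 8, 3, padding=-1)(x).shape)      # (1, 8, 64, 64)
    print(VanillaDeconv2d(3, 8, 3, padding=-1)(x).shape)    # (1, 8, 128, 128)
    print(GatedConv2d(3, 8, 3, padding=-1)(x).shape)        # (1, 8, 64, 64)
    print(GatedDeconv2d(3, 8, 3, padding=-1)(x).shape)      # (1, 8, 128, 128)

    out, new_mask = PartialConv2d(3, 8, 3, padding=-1)((x, mask))
    print(out.shape, new_mask.shape)                        # (1, 8, 64, 64), (1, 1, 64, 64)
    out, new_mask = PartialDeconv2d(3, 8, 3, padding=-1)((x, mask))
    print(out.shape, new_mask.shape)                        # (1, 8, 128, 128), (1, 1, 128, 128)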