|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import functools |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import grad |
|
from torch.nn import init |
|
|
|
|
|
def gradient(inputs, outputs):
    """Return d(outputs)/d(inputs), keeping the graph for higher-order grads.

    The result is the raw gradient tensor (first element of autograd.grad).
    Because allow_unused=True, it may be None when `outputs` does not depend
    on `inputs`.
    """
    # Seed the backward pass with ones so we differentiate sum(outputs).
    seed = torch.ones_like(outputs)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=seed,
        create_graph=True,   # keep graph so the gradient is itself differentiable
        retain_graph=True,
        only_inputs=True,
        allow_unused=True,
    )[0]
    return points_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def conv3x3(in_planes, out_planes, kernel=3, strd=1, dilation=1, padding=1, bias=False):
    """3x3 (by default) convolution with padding; every hyperparameter overridable."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel,
        stride=strd,
        padding=padding,
        dilation=dilation,
        bias=bias,
    )
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """Pointwise (1x1) convolution without bias."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return conv
|
|
|
|
|
def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):
        classname = m.__class__.__name__
        is_conv_or_linear = "Conv" in classname or "Linear" in classname
        if hasattr(m, "weight") and is_conv_or_linear:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm affine params: weight ~ N(1, init_gain), bias = 0.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
|
|
|
|
|
def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=()):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)     -- the network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (sequence of int) -- which GPUs the network runs on: e.g., [0, 1, 2]

    Return an initialized network.
    """
    # Default changed from a mutable list to an immutable tuple: a shared
    # mutable default argument is a classic Python pitfall, and the default
    # is only ever read here (len), so this is fully backward compatible.
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), "gpu_ids given but CUDA is not available"
        net = torch.nn.DataParallel(net)
    init_weights(net, init_type, init_gain=init_gain)
    return net
|
|
|
|
|
def imageSpaceRotation(xy, rot):
    """
    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles

    rotation center will be always image center (other rotation center can be represented by additional z translation)
    """
    # sin of each angle weights the corresponding coordinate channel.
    weights = torch.sin(rot).unsqueeze(-1).expand_as(xy)
    return (xy * weights).sum(dim=1)
|
|
|
|
|
def cal_gradient_penalty( |
|
netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0 |
|
): |
|
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 |
|
|
|
Arguments: |
|
netD (network) -- discriminator network |
|
real_data (tensor array) -- real images |
|
fake_data (tensor array) -- generated images from the generator |
|
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') |
|
type (str) -- if we mix real and fake data or not [real | fake | mixed]. |
|
constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 |
|
lambda_gp (float) -- weight for this loss |
|
|
|
Returns the gradient penalty loss |
|
""" |
|
if lambda_gp > 0.0: |
|
|
|
if type == "real": |
|
interpolatesv = real_data |
|
elif type == "fake": |
|
interpolatesv = fake_data |
|
elif type == "mixed": |
|
alpha = torch.rand(real_data.shape[0], 1) |
|
alpha = ( |
|
alpha.expand(real_data.shape[0], |
|
real_data.nelement() // |
|
real_data.shape[0]).contiguous().view(*real_data.shape) |
|
) |
|
alpha = alpha.to(device) |
|
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) |
|
else: |
|
raise NotImplementedError("{} not implemented".format(type)) |
|
interpolatesv.requires_grad_(True) |
|
disc_interpolates = netD(interpolatesv) |
|
gradients = torch.autograd.grad( |
|
outputs=disc_interpolates, |
|
inputs=interpolatesv, |
|
grad_outputs=torch.ones(disc_interpolates.size()).to(device), |
|
create_graph=True, |
|
retain_graph=True, |
|
only_inputs=True, |
|
) |
|
gradients = gradients[0].view(real_data.size(0), -1) |
|
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant)** |
|
2).mean() * lambda_gp |
|
return gradient_penalty, gradients |
|
else: |
|
return 0.0, None |
|
|
|
|
|
def get_norm_layer(norm_type="instance"):
    """Return a normalization layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    'none' yields None.
    """
    factories = {
        "batch": functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True),
        "instance": functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
        "group": functools.partial(nn.GroupNorm, 32),
        "none": None,
    }
    if norm_type not in factories:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return factories[norm_type]
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
|
|
|
|
|
class ConvBlock(nn.Module):
    """Pre-activation residual block whose output concatenates three conv stages.

    The three convs produce out_planes/2, out_planes/4 and out_planes/4
    channels; their concatenation (out_planes total) is added to the input,
    projected by a 1x1 conv when the channel counts differ.
    """

    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()
        k, s, d, p = opt.conv3x3
        half = int(out_planes / 2)
        quarter = int(out_planes / 4)
        self.conv1 = conv3x3(in_planes, half, k, s, d, p)
        self.conv2 = conv3x3(half, quarter, k, s, d, p)
        self.conv3 = conv3x3(quarter, quarter, k, s, d, p)

        # NOTE(review): norm layers are only created for "batch" / "group";
        # any other opt.norm leaves bn1..bn4 unset and forward() will fail.
        if opt.norm == "batch":
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.bn2 = nn.BatchNorm2d(half)
            self.bn3 = nn.BatchNorm2d(quarter)
            self.bn4 = nn.BatchNorm2d(in_planes)
        elif opt.norm == "group":
            self.bn1 = nn.GroupNorm(32, in_planes)
            self.bn2 = nn.GroupNorm(32, half)
            self.bn3 = nn.GroupNorm(32, quarter)
            self.bn4 = nn.GroupNorm(32, in_planes)

        if in_planes == out_planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
            )

    def forward(self, x):
        residual = x

        # Pre-activation ordering: norm -> ReLU (in place) -> conv.
        out1 = self.conv1(F.relu(self.bn1(x), inplace=True))
        out2 = self.conv2(F.relu(self.bn2(out1), inplace=True))
        out3 = self.conv3(F.relu(self.bn3(out2), inplace=True))

        out = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual

        return out
|
|