|
import torch |
|
from .cut_model import CUTModel |
|
|
|
|
|
class SinCUTModel(CUTModel): |
|
""" This class implements the single image translation model (Fig 9) of |
|
Contrastive Learning for Unpaired Image-to-Image Translation |
|
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu |
|
ECCV, 2020 |
|
""" |
|
|
|
@staticmethod |
|
def modify_commandline_options(parser, is_train=True): |
|
parser = CUTModel.modify_commandline_options(parser, is_train) |
|
parser.add_argument('--lambda_R1', type=float, default=1.0, |
|
help='weight for the R1 gradient penalty') |
|
parser.add_argument('--lambda_identity', type=float, default=1.0, |
|
help='the "identity preservation loss"') |
|
|
|
parser.set_defaults(nce_includes_all_negatives_from_minibatch=True, |
|
dataset_mode="singleimage", |
|
netG="stylegan2", |
|
stylegan2_G_num_downsampling=1, |
|
netD="stylegan2", |
|
gan_mode="nonsaturating", |
|
num_patches=1, |
|
nce_layers="0,2,4", |
|
lambda_NCE=4.0, |
|
ngf=10, |
|
ndf=8, |
|
lr=0.002, |
|
beta1=0.0, |
|
beta2=0.99, |
|
load_size=1024, |
|
crop_size=64, |
|
preprocess="zoom_and_patch", |
|
) |
|
|
|
if is_train: |
|
parser.set_defaults(preprocess="zoom_and_patch", |
|
batch_size=16, |
|
save_epoch_freq=1, |
|
save_latest_freq=20000, |
|
n_epochs=8, |
|
n_epochs_decay=8, |
|
|
|
) |
|
else: |
|
parser.set_defaults(preprocess="none", |
|
batch_size=1, |
|
num_test=1, |
|
) |
|
|
|
return parser |
|
|
|
def __init__(self, opt): |
|
super().__init__(opt) |
|
if self.isTrain: |
|
if opt.lambda_R1 > 0.0: |
|
self.loss_names += ['D_R1'] |
|
if opt.lambda_identity > 0.0: |
|
self.loss_names += ['idt'] |
|
|
|
def compute_D_loss(self): |
|
self.real_B.requires_grad_() |
|
GAN_loss_D = super().compute_D_loss() |
|
self.loss_D_R1 = self.R1_loss(self.pred_real, self.real_B) |
|
self.loss_D = GAN_loss_D + self.loss_D_R1 |
|
return self.loss_D |
|
|
|
def compute_G_loss(self): |
|
CUT_loss_G = super().compute_G_loss() |
|
self.loss_idt = torch.nn.functional.l1_loss(self.idt_B, self.real_B) * self.opt.lambda_identity |
|
return CUT_loss_G + self.loss_idt |
|
|
|
def R1_loss(self, real_pred, real_img): |
|
grad_real, = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True, retain_graph=True) |
|
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() |
|
return grad_penalty * (self.opt.lambda_R1 * 0.5) |
|
|