# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
            Only used by 'vanilla' and 'lsgan'; 'wgan' and 'hinge' do not
            build a label tensor.
        fake_label_val (float): The value for fake label. Default: 0.0.
            Only used by 'vanilla' and 'lsgan'.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            types.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            # Only the clamp is stored here; the hinge margin and sign are
            # applied in ``forward``.
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, ``-input.mean()`` when ``target`` is truthy
                and ``input.mean()`` otherwise.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return a Tensor with the same size as ``input`` filled with
                the real or fake label value.
        """
        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for a prediction.

        For 'hinge', the discriminator loss is ``mean(relu(1 - input))`` for
        real targets and ``mean(relu(1 + input))`` for fake targets, and the
        generator loss is ``-mean(input)`` regardless of ``target_is_real``.
        Every other type applies ``self.loss`` to the prediction and the
        target from :meth:`get_target_label`.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            # Hinge loss never reads a target tensor, so none is built for
            # this branch.
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight