Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.

            - 'vanilla': ``nn.BCEWithLogitsLoss`` against a constant target
              tensor filled with the real/fake label value.
            - 'lsgan': ``nn.MSELoss`` against the same constant target.
            - 'wgan': ``-input.mean()`` for real targets and
              ``input.mean()`` for fake targets.
            - 'hinge': for discriminators, ``relu(1 - input).mean()`` on
              real targets and ``relu(1 + input).mean()`` on fake targets;
              for generators, ``-input.mean()``.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            types.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            # ``_wgan_loss`` is a staticmethod, so this stores the plain
            # function and ``self.loss(input, target)`` receives exactly
            # two arguments (a bound method would also inject ``self``).
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            # Only applied on the discriminator path in ``forward``.
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, ``-input.mean()`` when ``target`` is truthy,
            otherwise ``input.mean()``.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
            return a tensor with the same size as ``input`` filled with
            ``real_label_val`` or ``fake_label_val``.
        """
        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (self.real_label_val
                      if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value. Scaled by ``loss_weight`` only when
            ``is_disc`` is False.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                # Real targets give relu(1 - input); fake give relu(1 + input).
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            # Built only here: the hinge branch never reads a target label.
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight