from typing import Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseAdversarialLoss:
def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for generator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator module
        :param discriminator: the discriminator module
:return: None
"""
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for discriminator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator module
        :param discriminator: the discriminator module
:return: None
"""
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Calculate generator loss
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
:return: total generator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()

def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
        Calculate discriminator loss
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
:return: total discriminator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()

def interpolate_mask(self, mask, shape):
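        """
        If allow_scale_mask is set, resize mask to the spatial size `shape` of the
        discriminator output; otherwise the shapes must already match.
        'maxpool' keeps a location masked if any pixel it covers was masked;
        any other mask_scale_mode is passed to F.interpolate (e.g. 'nearest').
        """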
assert mask is not None
assert self.allow_scale_mask or shape == mask.shape[-2:]
if shape != mask.shape[-2:] and self.allow_scale_mask:
if self.mask_scale_mode == 'maxpool':
mask = F.adaptive_max_pool2d(mask, shape)
else:
mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
return mask


def make_r1_gp(discr_real_pred, real_batch):
if torch.is_grad_enabled():
grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
else:
grad_penalty = 0
real_batch.requires_grad = False
return grad_penalty
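

# R1 regularization penalizes the squared gradient norm of the discriminator at real
# samples: R1 = E_x[ ||d D(x) / d x||^2 ]. For the gradient to exist, real_batch must
# have requires_grad=True before the discriminator forward pass (see
# pre_discriminator_step below). A minimal usage sketch, assuming `discriminator` and
# a real image batch `x` of shape [B, C, H, W]:
#
#     x.requires_grad = True
#     pred = discriminator(x)
#     penalty = make_r1_gp(pred, x)  # scalar, later scaled by gp_coef
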
class NonSaturatingWithR1(BaseAdversarialLoss):
def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
use_unmasked_for_gen=True, use_unmasked_for_discr=True):
self.gp_coef = gp_coef
self.weight = weight
        # use_unmasked_for_discr implies use_unmasked_for_gen;
        # otherwise we would train only the discriminator to pay attention to very small differences
assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask_as_fake_target implies use_unmasked_for_discr:
        # if we ignore the unmasked regions entirely,
        # it does not matter whether mask_as_fake_target is true or false
assert use_unmasked_for_discr or (not mask_as_fake_target)
self.use_unmasked_for_gen = use_unmasked_for_gen
self.use_unmasked_for_discr = use_unmasked_for_discr
self.mask_as_fake_target = mask_as_fake_target
self.allow_scale_mask = allow_scale_mask
self.mask_scale_mode = mask_scale_mode
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
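        # Non-saturating generator objective: softplus(-x) == -log(sigmoid(x)),
        # i.e. -log D(fake), which pushes the generator to make fakes score as real.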
fake_loss = F.softplus(-discr_fake_pred)
if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
not self.use_unmasked_for_gen: # == if masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
if not self.use_unmasked_for_gen:
fake_loss = fake_loss * mask
else:
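                # base weight of 1 everywhere, plus extra_mask_weight_for_gen inside the masked region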
pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
fake_loss = fake_loss * pixel_weights
return fake_loss.mean() * self.weight, dict()
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
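        # Standard non-saturating discriminator loss plus R1: softplus(-x) == -log(sigmoid(x))
        # for real samples and softplus(x) == -log(1 - sigmoid(x)) for fake samples.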
real_loss = F.softplus(-discr_real_pred)
grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
fake_loss = F.softplus(discr_fake_pred)
if not self.use_unmasked_for_discr or self.mask_as_fake_target:
# == if masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
# use_unmasked_for_discr=False only makes sense for fakes;
            # for real samples there is no difference between the two regions
fake_loss = fake_loss * mask
if self.mask_as_fake_target:
fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
sum_discr_loss = real_loss + grad_penalty + fake_loss
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=grad_penalty)
return sum_discr_loss.mean(), metrics
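

# A rough sketch of one discriminator update with NonSaturatingWithR1; `generator`,
# `discriminator`, `real`, `masks` and `discr_optimizer` are assumed to exist, and the
# generator call signature is only illustrative:
#
#     adv = NonSaturatingWithR1(gp_coef=0.001, weight=10)  # illustrative values
#     fake = generator(real, masks).detach()
#     adv.pre_discriminator_step(real, fake, generator, discriminator)
#     d_real, d_fake = discriminator(real), discriminator(fake)
#     d_loss, metrics = adv.discriminator_loss(real, fake, d_real, d_fake, mask=masks)
#     discr_optimizer.zero_grad(); d_loss.backward(); discr_optimizer.step()
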
class BCELoss(BaseAdversarialLoss):
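    """
    Per-pixel adversarial loss: the discriminator is trained to segment generated
    regions (target = mask for fakes, all zeros for reals), while the generator
    tries to drive the prediction on fakes towards all zeros as well.
    """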
def __init__(self, weight):
self.weight = weight
self.bce_loss = nn.BCEWithLogitsLoss()
def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros_like(discr_fake_pred)
fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
return fake_loss, dict()
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self,
mask: torch.Tensor,
discr_real_pred: torch.Tensor,
discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros_like(discr_real_pred)
sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=0)
return sum_discr_loss, metrics


def make_discrim_loss(kind, **kwargs):
if kind == 'r1':
return NonSaturatingWithR1(**kwargs)
elif kind == 'bce':
return BCELoss(**kwargs)
raise ValueError(f'Unknown adversarial loss kind {kind}')
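
# Illustrative construction via the factory (parameter values are assumptions,
# not defaults taken from any particular config):
#
#     adversarial = make_discrim_loss('r1', gp_coef=0.001, weight=10,
#                                     mask_as_fake_target=True, allow_scale_mask=True)
#     bce_adversarial = make_discrim_loss('bce', weight=1)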