|
from typing import List |
|
|
|
import torch |
|
import torch.nn as nn |
|
from lightning import LightningModule |
|
|
|
import wandb |
|
from .blocks import Encoder, Decoder, GaussianDistribution |
|
from torchmetrics import ( |
|
MeanSquaredError, |
|
MeanMetric, |
|
) |
|
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
|
from lpips import LPIPS |
|
|
|
from swim.models.discriminators import NLayerDiscriminator, weights_init, hinge_d_loss |
|
|
|
|
|
class Autoencoder(LightningModule):
    """KL-regularised convolutional autoencoder trained with an adversarial loss.

    The encoder output is projected by ``quant_conv`` into the moments of a
    ``GaussianDistribution`` over an ``emb_channels``-channel latent. Latents
    are projected back by ``post_quant_conv`` and decoded to image space.

    Training uses manual optimisation with two optimizers:

    * the autoencoder optimizer minimises
      L1 + LPIPS + ``kl_weight`` * KL + adaptive_weight * generator loss;
    * the discriminator optimizer minimises a hinge loss on real inputs vs.
      detached reconstructions.

    Args:
        channels: Base channel count passed to ``Encoder`` / ``Decoder``.
        channel_multipliers: Per-level channel multipliers for ``Encoder`` /
            ``Decoder``.
        n_resnet_blocks: Number of resnet blocks per level.
        in_channels: Channels of the input images.
        out_channels: Channels of the reconstructed images.
        z_channels: Feature width at the encoder/decoder bottleneck. The
            encoder output is expected to have ``2 * z_channels`` channels
            (see ``quant_conv``).
        emb_channels: Channels of the latent embedding.
        base_learning_rate: Initial value of ``self.learning_rate``, which is
            used by both optimizers.
        kl_weight: Weight of the KL term in the autoencoder loss.
    """

    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        base_learning_rate: float,
        kl_weight: float,
    ):
        super().__init__()

        # Keep constructor args in ``self.hparams`` / checkpoints without
        # sending them to the experiment logger.
        self.save_hyperparameters(logger=False)

        # ``training_step`` steps the autoencoder and discriminator optimizers
        # by hand.
        self.automatic_optimization = False

        self.kl_weight = kl_weight
        # Read by ``configure_optimizers``. Without this assignment that method
        # raised AttributeError unless a caller set the attribute externally.
        # Callers may still overwrite it before optimizers are configured.
        self.learning_rate = base_learning_rate

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )
        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # 1x1 conv: encoder features (2 * z_channels) -> distribution moments
        # (2 * emb_channels) consumed by GaussianDistribution.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # 1x1 conv: latent (emb_channels) -> decoder input (z_channels).
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        # Perceptual loss network with a VGG backbone.
        self.lpips = LPIPS(net="vgg")

        # Discriminator applied to both real images and reconstructions.
        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False
        ).apply(weights_init)

        self.train_psnr = PeakSignalNoiseRatio()
        self.train_ssim = StructuralSimilarityIndexMeasure()

        self.val_psnr = PeakSignalNoiseRatio()
        self.val_ssim = StructuralSimilarityIndexMeasure()
        self.val_mse = MeanSquaredError()
        self.val_lpips = MeanMetric()

    def encode(self, img: torch.Tensor) -> GaussianDistribution:
        """Encode ``img`` into a latent distribution.

        Args:
            img: Batch of images, presumably ``(B, in_channels, H, W)``.

        Returns:
            A ``GaussianDistribution`` built from the ``quant_conv`` output.
        """
        moments = self.quant_conv(self.encoder(img))
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents ``z`` (``emb_channels`` channels) back to images."""
        return self.decoder(self.post_quant_conv(z))

    def forward(self, img: torch.Tensor, sample: bool = False) -> torch.Tensor:
        """Reconstruct ``img``.

        Args:
            img: Batch of input images.
            sample: If True, decode a random sample from the posterior.
                Otherwise decode its mean, which is deterministic.

        Returns:
            The reconstructed batch.
        """
        posterior = self.encode(img)
        z = posterior.sample() if sample else posterior.mean
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        """Run one autoencoder update followed by one discriminator update.

        Args:
            batch: Mapping holding the input tensor under ``"images"``.
            batch_idx: Index of the batch (unused).
        """
        ae_opt, d_opt = self.optimizers()

        img = batch["images"]
        posterior = self.encode(img)
        recon = self.decode(posterior.sample())

        # ---- Autoencoder (generator) update ----
        batch_size = img.shape[0]
        # Reconstruction terms are summed over all elements and averaged over
        # the batch.
        l1_loss = torch.abs(img - recon).sum() / batch_size
        lpips_loss = self.lpips(recon, img).sum() / batch_size
        kl_loss = posterior.kl().mean()

        # The generator is rewarded for high discriminator logits on
        # reconstructions.
        logit_fake = self.discriminator(recon.contiguous())
        g_loss = -logit_fake.mean()

        # Balance the adversarial term against L1 using gradient norms at the
        # decoder's last layer.
        d_weight = self.calculate_adaptive_weight(l1_loss, g_loss)

        total_loss = (
            l1_loss + lpips_loss + self.kl_weight * kl_loss + d_weight * g_loss
        )

        # ``total_loss`` also back-propagates into the discriminator through
        # ``g_loss``. Those gradients are not stepped by ``ae_opt`` and are
        # cleared by ``d_opt.zero_grad()`` before the discriminator update.
        ae_opt.zero_grad()
        self.manual_backward(total_loss)
        ae_opt.step()

        self.train_psnr(recon, img)
        self.train_ssim(recon, img)

        self.log("train/l1_loss", l1_loss.item(), on_step=True, prog_bar=True)
        self.log("train/kl_loss", kl_loss.item(), on_step=True, prog_bar=True)
        self.log("train/lpips_loss", lpips_loss.item(), on_step=True, prog_bar=True)
        self.log("train/psnr", self.train_psnr, on_step=True, prog_bar=True)
        self.log("train/ssim", self.train_ssim, on_step=True, prog_bar=True)
        self.log("train/g_loss", g_loss.item(), on_step=True, prog_bar=True)
        self.log("train/d_weight", d_weight.item(), on_step=True, prog_bar=True)

        # ---- Discriminator update ----
        # Reconstructions are detached so only the discriminator is trained here.
        logit_real = self.discriminator(img.contiguous())
        logit_fake = self.discriminator(recon.detach().contiguous())
        d_loss = hinge_d_loss(logit_real, logit_fake)

        d_opt.zero_grad()
        self.manual_backward(d_loss)
        d_opt.step()

        self.log("train/d_loss", d_loss.item(), on_step=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Reconstruct deterministically (posterior mean) and update val metrics.

        On the first validation batch, image pairs are also sent to wandb.
        """
        img = batch["images"]
        recon = self(img)
        lpips_loss = self.lpips(recon, img)

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)
        self.val_lpips(lpips_loss)

        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
        self.log(
            "val/lpips", self.val_lpips, on_epoch=True, on_step=False, prog_bar=True
        )

        if batch_idx == 0:
            self.log_images(img, recon)

    def compile(self, mode: str = "max-autotune", **compile_kwargs):
        """Replace the main submodules with ``torch.compile``-d wrappers in place.

        Args:
            mode: ``torch.compile`` mode. Defaults to ``"max-autotune"``.
            **compile_kwargs: Extra keyword arguments forwarded to
                ``torch.compile``.

        NOTE: the compiled wrappers nest the original module under
        ``_orig_mod``. State-dict keys of the compiled submodules therefore
        change once this method has been called.
        """
        for attr in (
            "encoder",
            "decoder",
            "quant_conv",
            "post_quant_conv",
            "lpips",
            "discriminator",
        ):
            # ``torch.compile`` accepts only the model positionally; ``mode``
            # must be passed by keyword.
            compiled = torch.compile(getattr(self, attr), mode=mode, **compile_kwargs)
            setattr(self, attr, compiled)

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    def calculate_adaptive_weight(self, rec_loss, g_loss):
        """Compute the scale applied to the generator loss.

        The ratio ``||d rec_loss|| / (||d g_loss|| + 1e-4)`` is computed from
        gradients with respect to ``get_last_layer()``. It is clamped to
        ``[0, 1e4]``, detached so it acts as a constant, and then halved.

        Args:
            rec_loss: Reconstruction loss (scalar tensor with a graph).
            g_loss: Generator adversarial loss (scalar tensor with a graph).

        Returns:
            A detached scalar tensor.
        """
        last_layer = self.get_last_layer()

        # ``retain_graph`` keeps the graph alive for the subsequent
        # ``manual_backward(total_loss)``.
        rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * 0.5

    def log_images(self, ori_images, recon_images):
        """Log original / reconstructed image pairs to the active wandb run.

        Each image is assumed to be a ``(C, H, W)`` tensor in ``[-1, 1]``
        (TODO confirm against the data pipeline). Values are clamped to that
        range before the ``uint8`` cast, so out-of-range reconstructions
        saturate instead of wrapping around.

        All images are sent in a single ``wandb.log`` call, so they share one
        wandb step instead of each call advancing the step. Nothing is logged
        when there is no active wandb run.
        """
        if wandb.run is None:
            return

        def to_uint8_hwc(image: torch.Tensor):
            # [-1, 1] -> [0, 255] as uint8, then CHW -> HWC. An HWC uint8
            # array is displayed as-is by wandb.Image, without re-normalising.
            image = ((image.detach().clamp(-1, 1) + 1) * 127.5).to(torch.uint8)
            return image.permute(1, 2, 0).cpu().numpy()

        wandb.log(
            {
                "val/ori_image": [wandb.Image(to_uint8_hwc(x)) for x in ori_images],
                "val/recon_image": [
                    wandb.Image(to_uint8_hwc(x)) for x in recon_images
                ],
            }
        )

    def configure_optimizers(self):
        """Create the autoencoder and discriminator AdamW optimizers.

        Both optimizers use ``self.learning_rate``, which is initialised from
        ``base_learning_rate``, with ``betas=(0.5, 0.9)``. The LPIPS network
        is not optimised.

        Returns:
            ``[ae_opt, d_opt]``, in the order ``training_step`` unpacks them.
        """
        ae_params = [
            *self.encoder.parameters(),
            *self.decoder.parameters(),
            *self.quant_conv.parameters(),
            *self.post_quant_conv.parameters(),
        ]

        ae_opt = torch.optim.AdamW(
            ae_params,
            lr=self.learning_rate,
            betas=(0.5, 0.9),
        )
        d_opt = torch.optim.AdamW(
            self.discriminator.parameters(),
            lr=self.learning_rate,
            betas=(0.5, 0.9),
        )

        return [ae_opt, d_opt]
|
|