from typing import List

import torch
import torch.nn as nn
import wandb
from lightning import LightningModule
from lpips import LPIPS
from torchmetrics import MeanMetric, MeanSquaredError
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

from swim.models.discriminators import NLayerDiscriminator, hinge_d_loss, weights_init

from .blocks import Decoder, Encoder, GaussianDistribution


class Autoencoder(LightningModule):
    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        base_learning_rate: float,
        kl_weight: float,
        gan_start: int,
    ):
        super().__init__()
        self.save_hyperparameters(logger=False)

        # Two optimizers (autoencoder and discriminator) are stepped by hand
        # in `training_step`, so automatic optimization must be disabled
        self.automatic_optimization = False

        self.kl_weight = kl_weight

        # Default learning rate; `configure_optimizers` reads this attribute,
        # so callers may override it (e.g. after scaling by batch size)
        self.learning_rate = base_learning_rate

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )

        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # Convolution to map the encoder output (mean and log-variance,
        # hence the factor of 2) into the quantized embedding space
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)

        # Convolution to map from the quantized embedding space back to
        # the embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        self.lpips = LPIPS(net="vgg")

        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False
        ).apply(weights_init)

        self.val_psnr = PeakSignalNoiseRatio()
        self.val_ssim = StructuralSimilarityIndexMeasure()
        self.val_mse = MeanSquaredError()
        self.val_lpips = MeanMetric()

    def encode(self, img: torch.Tensor) -> GaussianDistribution:
        z = self.encoder(img)
        # Get the moments in the quantized embedding space
        moments = self.quant_conv(z)
        # Return the distribution
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor):
        # Map to embedding space from the quantized representation
        z = self.post_quant_conv(z)
        # Decode the image of shape `[batch_size, channels, height, width]`
        return self.decoder(z)

    def forward(self, img: torch.Tensor, sample: bool = False):
        # Encode the image
        z_dis = self.encode(img)
        # Either sample from the posterior or take its mean for a
        # deterministic reconstruction
        if sample:
            z = z_dis.sample()
        else:
            z = z_dis.mean
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        ae_opt, d_opt = self.optimizers()
        # The adversarial term is only enabled after `gan_start` steps
        opt_gan = self.global_step >= self.hparams.gan_start

        # optimize the autoencoder
        # Get the image
        img = batch["images"]
        z_dis = self.encode(img)
        z = z_dis.sample()
        recon = self.decode(z)

        # Calculate the losses, each averaged over the batch
        B = img.shape[0]
        l1_loss = torch.abs(img - recon).sum() / B  # L1 loss
        lpips_loss = self.lpips.forward(recon, img).sum() / B  # LPIPS loss
        kl_loss = z_dis.kl().mean()  # KL loss

        if opt_gan:
            logit_fake = self.discriminator(recon.contiguous())
            g_loss = -logit_fake.mean()
            d_weight = self.calculate_adaptive_weight(l1_loss, g_loss)
        else:
            # Placeholder values so the logs stay consistent before the GAN
            # phase starts; d_weight = 0 disables the adversarial term
            g_loss = torch.tensor(0.5, device=self.device)
            d_weight = torch.tensor(0.0, device=self.device)

        total_loss = l1_loss + lpips_loss + self.kl_weight * kl_loss + d_weight * g_loss

        ae_opt.zero_grad()
        self.manual_backward(total_loss)
        ae_opt.step()

        # Log the losses
        self.log("train/l1_loss", l1_loss.item(), on_step=True, prog_bar=True)
        self.log("train/kl_loss", kl_loss.item(), on_step=True, prog_bar=True)
        self.log("train/lpips_loss", lpips_loss.item(), on_step=True, prog_bar=True)
        self.log("train/g_loss", g_loss.item(), on_step=True, prog_bar=True)
        self.log("train/d_weight", d_weight.item(), on_step=True, prog_bar=True)

        # optimize the discriminator
        if opt_gan:
            logit_real = self.discriminator(img.contiguous())
            # Detach the reconstruction so no gradients flow into the decoder
            logit_fake = self.discriminator(recon.detach().contiguous())
            d_loss = hinge_d_loss(logit_real, logit_fake)

            d_opt.zero_grad()
            self.manual_backward(d_loss)
            d_opt.step()
        else:
            d_loss = torch.tensor(0.5, device=self.device)

        self.log("train/d_loss", d_loss.item(), on_step=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        # Get the image
        img = batch["images"]
        # Reconstruct it deterministically from the posterior mean
        recon = self.forward(img)

        lpips_loss = self.lpips.forward(recon, img)  # LPIPS loss

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)
        self.val_lpips(lpips_loss)

        # Log the metrics
        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
        self.log(
            "val/lpips", self.val_lpips, on_epoch=True, on_step=False, prog_bar=True
        )

        if batch_idx == 0 and self.trainer.is_global_zero:
            num_imgs = min(img.shape[0], 16)
            self.log_images(img[:num_imgs], recon[:num_imgs])

    def compile(self):
        # Replace the heavy submodules with `torch.compile`d versions;
        # note that `mode` is a keyword-only argument
        for attr in [
            "encoder",
            "decoder",
            "quant_conv",
            "post_quant_conv",
            "lpips",
            "discriminator",
        ]:
            setattr(self, attr, torch.compile(getattr(self, attr), mode="max-autotune"))

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def calculate_adaptive_weight(self, rec_loss, g_loss):
        # Balance the adversarial loss against the reconstruction loss by
        # matching their gradient norms at the decoder's last layer
        last_layer = self.decoder.conv_out.weight

        rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * 0.5
        return d_weight

    def log_images(self, ori_images, recon_images):
        for ori_image, recon_image in zip(ori_images, recon_images):
            # Map from [-1, 1] back to [0, 255]
            ori_image = ((ori_image + 1) * 127.5).cpu().type(torch.uint8)
            recon_image = ((recon_image + 1) * 127.5).cpu().type(torch.uint8)

            wandb.log(
                {
                    "val/ori_image": [
                        wandb.Image(ori_image),
                    ],
                    "val/recon_image": [
                        wandb.Image(recon_image),
                    ],
                }
            )

    def configure_optimizers(self):
        encoder_params = list(self.encoder.parameters())
        decoder_params = list(self.decoder.parameters())
        other_params = list(self.quant_conv.parameters()) + list(
            self.post_quant_conv.parameters()
        )
        discriminator_params = list(self.discriminator.parameters())

        ae_opt = torch.optim.AdamW(
            encoder_params + decoder_params + other_params,
            lr=self.learning_rate,
            betas=(0.5, 0.9),
        )
        d_opt = torch.optim.AdamW(
            discriminator_params, lr=self.learning_rate, betas=(0.5, 0.9)
        )

        return [ae_opt, d_opt]
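
# Example usage (a minimal sketch, not part of the module). The hyperparameter
# values below are illustrative assumptions, not the project's defaults, and the
# datamodule supplying batches of the form {"images": Tensor[B, 3, H, W]} scaled
# to [-1, 1] is assumed to be defined elsewhere:
#
#     from lightning import Trainer
#
#     model = Autoencoder(
#         channels=128,
#         channel_multipliers=[1, 2, 4, 4],
#         n_resnet_blocks=2,
#         in_channels=3,
#         out_channels=3,
#         z_channels=4,
#         emb_channels=4,
#         base_learning_rate=4.5e-6,
#         kl_weight=1e-6,
#         gan_start=50_000,  # step at which the adversarial term kicks in
#     )
#     trainer = Trainer(max_steps=100_000)
#     trainer.fit(model, datamodule=...)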