from typing import List

import torch
import torch.nn as nn
from lightning import LightningModule
import wandb
from .blocks import Encoder, Decoder, GaussianDistribution
from torchmetrics import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    MeanSquaredError,
)


class Autoencoder(LightningModule):
    """
    ## Autoencoder

    Lightning module that chains an `Encoder`, a 1x1 "quant" convolution,
    a 1x1 "post-quant" convolution and a `Decoder`, and trains the whole
    stack with a summed L1 reconstruction loss.
    """

    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        learning_rate: float = 1e-4,
    ):
        """
        :param channels: forwarded to both `Encoder` and `Decoder`
        :param channel_multipliers: forwarded to both `Encoder` and `Decoder`
        :param n_resnet_blocks: forwarded to both `Encoder` and `Decoder`
        :param in_channels: forwarded to the `Encoder` (input image channels)
        :param out_channels: forwarded to the `Decoder` (output image channels)
        :param z_channels: is the number of channels in the embedding space
            (input channels of `quant_conv`, output channels of
            `post_quant_conv`)
        :param emb_channels: is the number of channels in the quantized
            embedding space (output channels of `quant_conv`)
        :param learning_rate: is the learning rate handed to AdamW
        """
        super().__init__()
        # Store the constructor arguments on `self.hparams` without sending
        # them to the logger.
        self.save_hyperparameters(logger=False)

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )
        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # 1x1 convolution mapping from the embedding space to the
        # quantized embedding space
        self.quant_conv = nn.Conv2d(z_channels, emb_channels, 1)
        # 1x1 convolution mapping from the quantized embedding space back to
        # the embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        # NOTE(review): the metrics use their default `data_range`; images
        # look like they live in [-1, 1] (see `log_images`) -- confirm
        # whether an explicit `data_range` should be passed.
        self.train_psnr = PeakSignalNoiseRatio()
        self.train_ssim = StructuralSimilarityIndexMeasure()

        self.val_psnr = PeakSignalNoiseRatio()
        self.val_ssim = StructuralSimilarityIndexMeasure()
        self.val_mse = MeanSquaredError()

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        """
        ### Encode images to latent representation

        :param img: is the image tensor with shape
            `[batch_size, img_channels, img_height, img_width]`
        :return: the output of `quant_conv`, a tensor with `emb_channels`
            channels
        """
        # Run the encoder, then project into the quantized embedding space.
        z = self.encoder(img)
        z = self.quant_conv(z)
        # The tensor is returned as-is; wrapping it in a
        # `GaussianDistribution` is currently disabled:
        # return GaussianDistribution(moments)
        return z

    def decode(self, z: torch.Tensor):
        """
        ### Decode images from latent representation

        :param z: is the latent representation with shape
            `[batch_size, emb_channels, z_height, z_height]`
        :return: the decoder output for the given latent
        """
        # Map to embedding space from the quantized representation
        z = self.post_quant_conv(z)
        # Decode the image
        return self.decoder(z)

    def forward(self, img: torch.Tensor):
        """
        Reconstruct `img` by encoding and immediately decoding it.

        :param img: is the image tensor with shape
            `[batch_size, img_channels, img_height, img_width]`
        :return: the reconstruction produced by `decode(encode(img))`
        """
        z = self.encode(img)
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        """
        ### Training step

        :param batch: is the batch of data; the images are read from
            `batch["images"]`
        :param batch_idx: is the index of the batch
        :return: the L1 loss summed over every element of the batch
        """
        img = batch["images"]
        recon = self.forward(img)

        # L1 loss, summed (not averaged) over all elements.
        # NOTE(review): the magnitude therefore scales with batch and image
        # size -- confirm this is intended before changing the reduction.
        loss = torch.abs(img - recon).sum()

        self.train_psnr(recon, img)
        self.train_ssim(recon, img)

        # Log the detached tensor rather than `loss.item()` so that no
        # device-to-host synchronisation is forced on every step.
        self.log(
            "train/l1_loss",
            loss.detach(),
            on_epoch=False,
            on_step=True,
            prog_bar=True,
        )
        self.log(
            "train/psnr", self.train_psnr, on_epoch=True, on_step=False, prog_bar=True
        )
        self.log(
            "train/ssim", self.train_ssim, on_epoch=True, on_step=False, prog_bar=True
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """
        ### Validation step

        Updates the validation PSNR / SSIM / MSE metrics and, for the first
        batch only, logs the originals and their reconstructions as images.

        :param batch: is the batch of data; the images are read from
            `batch["images"]`
        :param batch_idx: is the index of the batch
        """
        img = batch["images"]
        recon = self.forward(img)

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)

        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)

        if batch_idx == 0:
            self.log_images(img, recon)

    # def setup(self, stage: str) -> None:
    #     if self.hparams.compile and stage == "fit":
    #         self.net = torch.compile(self.net)

    @staticmethod
    def _to_uint8(images: torch.Tensor) -> torch.Tensor:
        """Map values from `[-1, 1]` to `[0, 255]` and cast to `uint8` on CPU."""
        scaled = (images.detach() + 1) * 127.5
        # Clamp before the cast: out-of-range values (e.g. reconstructions
        # slightly outside [-1, 1]) would otherwise wrap around in uint8.
        return scaled.clamp(0, 255).cpu().type(torch.uint8)

    def log_images(self, ori_images, recon_images):
        """
        ### Log images

        Sends the original and reconstructed images to Weights & Biases
        under `val/ori_image` and `val/recon_image`.

        :param ori_images: is the batch of original images
        :param recon_images: is the batch of reconstructed images
        """
        # Without an active run `wandb.log` raises, which would abort
        # validation; skip image logging in that case.
        if wandb.run is None:
            return

        # One `wandb.log` call for the whole batch so that every image is
        # kept and the wandb step only advances once.
        wandb.log(
            {
                "val/ori_image": [
                    wandb.Image(image) for image in self._to_uint8(ori_images)
                ],
                "val/recon_image": [
                    wandb.Image(image) for image in self._to_uint8(recon_images)
                ],
            }
        )

    def configure_optimizers(self):
        """
        ### Configure optimizers

        :return: an AdamW optimizer over the encoder, decoder and both 1x1
            convolutions, using `hparams.learning_rate`
        """
        encoder_params = list(self.encoder.parameters())
        decoder_params = list(self.decoder.parameters())
        other_params = list(self.quant_conv.parameters()) + list(
            self.post_quant_conv.parameters()
        )
        return torch.optim.AdamW(
            encoder_params + decoder_params + other_params,
            lr=self.hparams.learning_rate,
        )