|
from typing import List |
|
|
|
import torch |
|
import torch.nn as nn |
|
from lightning import LightningModule |
|
|
|
import wandb |
|
from .blocks import Encoder, Decoder, GaussianDistribution |
|
from torchmetrics import ( |
|
PeakSignalNoiseRatio, |
|
StructuralSimilarityIndexMeasure, |
|
MeanSquaredError, |
|
) |
|
|
|
|
|
class Autoencoder(LightningModule):
    """
    ## Autoencoder

    Encoder/decoder pair with 1x1 convolutions mapping between the encoder
    output space (`z_channels`) and the embedding space (`emb_channels`).
    Trains with an L1 reconstruction loss and tracks PSNR/SSIM (train and
    val) plus MSE (val) via torchmetrics.
    """

    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        learning_rate: float = 1e-4,
    ):
        """
        :param channels: number of base channels in the encoder/decoder
        :param channel_multipliers: per-resolution channel multipliers
        :param n_resnet_blocks: number of ResNet blocks at each resolution
        :param in_channels: number of channels in the input image
        :param out_channels: number of channels in the reconstructed image
        :param z_channels: number of channels in the encoder output
        :param emb_channels: number of dimensions in the embedding space
        :param learning_rate: learning rate for the AdamW optimizer
        """
        super().__init__()

        # Store all init args on `self.hparams` (read back in
        # `configure_optimizers`); `logger=False` keeps them out of the
        # experiment logger.
        self.save_hyperparameters(logger=False)

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )

        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # 1x1 convolution: encoder output -> embedding space.
        self.quant_conv = nn.Conv2d(z_channels, emb_channels, 1)

        # 1x1 convolution: embedding space -> decoder input.
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        # Separate metric instances per stage so running states never mix.
        self.train_psnr = PeakSignalNoiseRatio()
        self.train_ssim = StructuralSimilarityIndexMeasure()

        self.val_psnr = PeakSignalNoiseRatio()
        self.val_ssim = StructuralSimilarityIndexMeasure()
        self.val_mse = MeanSquaredError()

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        """
        ### Encode images to latent representation

        :param img: image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        :return: embedding tensor with shape `[batch_size, emb_channels, z_height, z_width]`

        NOTE(review): the original annotation said `GaussianDistribution`,
        but the encoder output is passed straight through `quant_conv` and
        returned as a plain tensor — the annotation was corrected to match.
        """
        z = self.encoder(img)

        # Map from z-space to the embedding space.
        z = self.quant_conv(z)

        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        ### Decode images from latent representation

        :param z: latent representation with shape `[batch_size, emb_channels, z_height, z_width]`
        :return: reconstructed image tensor
        """
        # Map from the embedding space back to z-space, then decode.
        z = self.post_quant_conv(z)

        return self.decoder(z)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Encode then decode: returns the reconstruction of `img`.

        :param img: image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        z = self.encode(img)
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        """
        ### Training step

        :param batch: batch dict; images are read from `batch["images"]`
        :param batch_idx: index of the batch
        :return: the L1 reconstruction loss for backprop
        """
        img = batch["images"]
        recon = self.forward(img)

        # L1 reconstruction loss. NOTE(review): `.sum()` makes the loss
        # magnitude scale with batch and image size; `.mean()` would decouple
        # the effective learning rate from batch size — kept as-is to
        # preserve training behavior.
        loss = torch.abs(img - recon).sum()

        # Update running metric states; logging the metric objects below lets
        # Lightning handle epoch aggregation and reset.
        self.train_psnr(recon, img)
        self.train_ssim(recon, img)

        self.log(
            "train/l1_loss", loss.item(), on_epoch=False, on_step=True, prog_bar=True
        )
        self.log(
            "train/psnr", self.train_psnr, on_epoch=True, on_step=False, prog_bar=True
        )
        self.log(
            "train/ssim", self.train_ssim, on_epoch=True, on_step=False, prog_bar=True
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """
        ### Validation step

        :param batch: batch dict; images are read from `batch["images"]`
        :param batch_idx: index of the batch
        """
        img = batch["images"]

        recon = self.forward(img)

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)

        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)

        # Log sample reconstructions once per validation epoch.
        if batch_idx == 0:
            self.log_images(img, recon)

    def log_images(self, ori_images, recon_images):
        """
        ### Log original/reconstruction image pairs to wandb

        Assumes images are normalized to `[-1, 1]` (TODO confirm against the
        datamodule); values are clamped before the uint8 conversion.
        """
        for ori_image, recon_image in zip(ori_images, recon_images):
            # Clamp before casting: without it, reconstruction values outside
            # [-1, 1] wrap around when converted to uint8 and corrupt the
            # logged images.
            ori_image = ((ori_image.clamp(-1, 1) + 1) * 127.5).cpu().type(torch.uint8)
            recon_image = (
                (recon_image.clamp(-1, 1) + 1) * 127.5
            ).cpu().type(torch.uint8)

            wandb.log(
                {
                    "val/ori_image": [
                        wandb.Image(ori_image),
                    ],
                    "val/recon_image": [
                        wandb.Image(recon_image),
                    ],
                }
            )

    def configure_optimizers(self):
        """
        ### Configure optimizers

        AdamW over the encoder, decoder and both 1x1 quant convolutions
        (the torchmetrics modules hold no trainable parameters).
        """
        encoder_params = list(self.encoder.parameters())
        decoder_params = list(self.decoder.parameters())
        other_params = list(self.quant_conv.parameters()) + list(
            self.post_quant_conv.parameters()
        )

        return torch.optim.AdamW(
            encoder_params + decoder_params + other_params,
            lr=self.hparams.learning_rate,
        )
|
|