# swim_new / swim / models / autoencoder.py
# Uploaded by qninhdt; last commit 8cc0674 ("cc"); 6.22 kB.
from typing import List
import torch
import torch.nn as nn
from lightning import LightningModule
import wandb
from .blocks import Encoder, Decoder, GaussianDistribution
from torchmetrics import (
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
MeanSquaredError,
)
class Autoencoder(LightningModule):
    """
    ## Autoencoder

    Lightning module wrapping an `Encoder` / `Decoder` pair. Two 1x1 convolutions
    sit between them: `quant_conv` maps the encoder output (`z_channels`) to the
    latent space (`emb_channels`), and `post_quant_conv` maps it back.

    The Gaussian sampling step (`GaussianDistribution`) is currently disabled.
    `encode` therefore returns the latent tensor directly, and the model is
    trained as a deterministic autoencoder with an L1 reconstruction loss.
    """

    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        learning_rate: float = 1e-4,
        data_range: float = 2.0,
    ):
        """
        :param channels: base channel count, passed to both `Encoder` and `Decoder`
        :param channel_multipliers: channel multipliers, passed to both `Encoder`
            and `Decoder`
        :param n_resnet_blocks: number of resnet blocks, passed to both `Encoder`
            and `Decoder`
        :param in_channels: number of channels of the input images
        :param out_channels: number of channels of the reconstructed images
        :param z_channels: number of channels produced by the encoder and
            consumed by the decoder
        :param emb_channels: number of channels of the latent returned by `encode`
        :param learning_rate: AdamW learning rate
        :param data_range: value range (max - min) of the images, used by the
            PSNR / SSIM metrics. The default of 2.0 matches images in `[-1, 1]`,
            which is the range `log_images` assumes.
        """
        super().__init__()
        self.save_hyperparameters(logger=False)

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )
        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # 1x1 convolution: encoder output (z_channels) -> latent (emb_channels)
        self.quant_conv = nn.Conv2d(z_channels, emb_channels, 1)
        # 1x1 convolution: latent (emb_channels) -> decoder input (z_channels)
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        # Separate metric instances per stage, so train and val accumulators never
        # mix. An explicit data_range keeps PSNR/SSIM from being inferred from
        # whatever min/max each batch happens to contain.
        self.train_psnr = PeakSignalNoiseRatio(data_range=data_range)
        self.train_ssim = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.val_psnr = PeakSignalNoiseRatio(data_range=data_range)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.val_mse = MeanSquaredError()

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        """
        ### Encode images to latent representation

        :param img: image tensor with shape
            `[batch_size, img_channels, img_height, img_width]`
        :return: latent tensor with `emb_channels` channels
        """
        z = self.encoder(img)
        # GaussianDistribution sampling is disabled, so the projected latent is
        # returned as-is.
        return self.quant_conv(z)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        ### Decode images from latent representation

        :param z: latent tensor with shape `[batch_size, emb_channels, z_height, z_width]`
        :return: the decoder output (reconstructed images)
        """
        # Map the latent back to the decoder's input channel count.
        return self.decoder(self.post_quant_conv(z))

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Encode and then decode `img`.

        :param img: image tensor with shape
            `[batch_size, img_channels, img_height, img_width]`
        :return: the reconstruction of `img`
        """
        return self.decode(self.encode(img))

    def training_step(self, batch, batch_idx):
        """
        ### Training step

        :param batch: mapping that holds the images under the `"images"` key
        :param batch_idx: index of the batch (unused)
        :return: the L1 reconstruction loss (summed over all elements)
        """
        img = batch["images"]
        recon = self(img)

        # L1 loss summed over every element (batch, channels, pixels), not averaged.
        loss = torch.abs(img - recon).sum()

        self.train_psnr(recon, img)
        self.train_ssim(recon, img)

        # Log a detached tensor and let Lightning handle the conversion, instead
        # of forcing a host-device sync with `.item()` on every step.
        self.log(
            "train/l1_loss", loss.detach(), on_epoch=False, on_step=True, prog_bar=True
        )
        self.log(
            "train/psnr", self.train_psnr, on_epoch=True, on_step=False, prog_bar=True
        )
        self.log(
            "train/ssim", self.train_ssim, on_epoch=True, on_step=False, prog_bar=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """
        ### Validation step

        Updates the epoch-level PSNR / SSIM / MSE metrics and, for the first
        batch only, logs original/reconstructed image pairs.

        :param batch: mapping that holds the images under the `"images"` key
        :param batch_idx: index of the batch
        """
        img = batch["images"]
        recon = self(img)

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)

        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)

        if batch_idx == 0:
            self.log_images(img, recon)

    def log_images(self, ori_images: torch.Tensor, recon_images: torch.Tensor) -> None:
        """
        ### Log images

        Logs original and reconstructed images to the active wandb run in a
        single `wandb.log` call. Images are assumed to lie in `[-1, 1]`.
        The call is a no-op when no wandb run is active, e.g. when wandb was
        never initialised, or on processes that did not start a run.

        :param ori_images: batch of original images
        :param recon_images: batch of reconstructed images (same batch size)
        """
        if wandb.run is None:
            return

        def _to_uint8(images: torch.Tensor) -> torch.Tensor:
            # Clamp first: the decoder output is unbounded, and casting values
            # outside [0, 255] to uint8 would wrap around and corrupt the image.
            scaled = (images.detach().clamp(-1.0, 1.0) + 1) * 127.5
            return scaled.cpu().type(torch.uint8)

        ori_uint8 = _to_uint8(ori_images)
        recon_uint8 = _to_uint8(recon_images)

        # One call for the whole batch, so the wandb step advances only once.
        wandb.log(
            {
                "val/ori_image": [wandb.Image(image) for image in ori_uint8],
                "val/recon_image": [wandb.Image(image) for image in recon_uint8],
            }
        )

    def configure_optimizers(self):
        """
        ### Configure optimizers

        AdamW over all module parameters: the encoder, the decoder and both 1x1
        convolutions. The metric modules contribute none.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)