# swim/models/autoencoder.py
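"""KL-regularized convolutional autoencoder with adversarial training.

Encodes images into a Gaussian latent distribution and decodes them back.
Trained with an L1 reconstruction loss, an LPIPS perceptual loss, a KL
penalty on the latent, and a PatchGAN hinge adversarial loss, in the style
of the VQGAN / latent-diffusion autoencoder recipe.
"""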
from typing import List
import torch
import torch.nn as nn
from lightning import LightningModule
import wandb
from .blocks import Encoder, Decoder, GaussianDistribution
from torchmetrics import (
MeanSquaredError,
MeanMetric,
)
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from lpips import LPIPS
from swim.models.discriminators import NLayerDiscriminator, weights_init, hinge_d_loss
class Autoencoder(LightningModule):
def __init__(
self,
channels: int,
channel_multipliers: List[int],
n_resnet_blocks: int,
in_channels: int,
out_channels: int,
z_channels: int,
emb_channels: int,
base_learning_rate: float,
kl_weight: float,
):
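        """
        Args:
            channels: base channel count for the encoder/decoder.
            channel_multipliers: per-resolution multipliers of `channels`.
            n_resnet_blocks: number of ResNet blocks per resolution.
            in_channels: number of channels in the input image.
            out_channels: number of channels in the reconstructed image.
            z_channels: number of channels in the encoder/decoder latent.
            emb_channels: number of channels in the embedding (posterior) space.
            base_learning_rate: base learning rate for both optimizers.
            kl_weight: weight of the KL term in the autoencoder loss.
        """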
super().__init__()
self.save_hyperparameters(logger=False)
self.automatic_optimization = False
self.kl_weight = kl_weight
self.encoder = Encoder(
channels=channels,
channel_multipliers=channel_multipliers,
n_resnet_blocks=n_resnet_blocks,
in_channels=in_channels,
z_channels=z_channels,
)
self.decoder = Decoder(
channels=channels,
channel_multipliers=channel_multipliers,
n_resnet_blocks=n_resnet_blocks,
out_channels=out_channels,
z_channels=z_channels,
)
        # Convolution to map the encoder output (mean and log-variance) into
        # the embedding space
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # Convolution to map from the embedding space back to the decoder's
        # latent space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
        # Perceptual loss network (LPIPS, VGG backbone), kept frozen
        self.lpips = LPIPS(net="vgg").eval()
        for param in self.lpips.parameters():
            param.requires_grad = False
        # PatchGAN discriminator for the adversarial loss
        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False
        ).apply(weights_init)
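        # Image-quality metrics; TorchMetrics modules are registered as
        # attributes so Lightning moves them to the right device and handles
        # epoch-level aggregation via `self.log`.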
self.train_psnr = PeakSignalNoiseRatio()
self.train_ssim = StructuralSimilarityIndexMeasure()
self.val_psnr = PeakSignalNoiseRatio()
self.val_ssim = StructuralSimilarityIndexMeasure()
self.val_mse = MeanSquaredError()
self.val_lpips = MeanMetric()
def encode(self, img: torch.Tensor) -> GaussianDistribution:
z = self.encoder(img)
        # Get the moments (mean and log-variance) in the embedding space
moments = self.quant_conv(z)
# Return the distribution
return GaussianDistribution(moments)
def decode(self, z: torch.Tensor):
        # Map from the embedding space back to the decoder's latent space
z = self.post_quant_conv(z)
# Decode the image of shape `[batch_size, channels, height, width]`
return self.decoder(z)
def forward(self, img: torch.Tensor, sample: bool = False):
# Encode the image
z_dis = self.encode(img)
if sample:
z = z_dis.sample()
else:
z = z_dis.mean
return self.decode(z)
def training_step(self, batch, batch_idx):
ae_opt, d_opt = self.optimizers()
# optimize the autoencoder
# Get the image
img = batch["images"]
z_dis = self.encode(img)
z = z_dis.sample()
recon = self.decode(z)
# Calculate the loss
B = img.shape[0]
l1_loss = torch.abs(img - recon).sum() / B # L1 loss
lpips_loss = self.lpips.forward(recon, img).sum() / B # LPIPS loss
kl_loss = z_dis.kl().mean() # KL loss
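        # Adversarial (generator) term: push the discriminator to score the
        # reconstructions as real; scaled by an adaptive weight that balances
        # reconstruction and adversarial gradients at the decoder's last layer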
logit_fake = self.discriminator(recon.contiguous())
g_loss = -logit_fake.mean()
d_weight = self.calculate_adaptive_weight(l1_loss, g_loss)
total_loss = l1_loss + lpips_loss + self.kl_weight * kl_loss + d_weight * g_loss
ae_opt.zero_grad()
self.manual_backward(total_loss)
ae_opt.step()
self.train_psnr(recon, img)
self.train_ssim(recon, img)
# Log the loss
self.log("train/l1_loss", l1_loss.item(), on_step=True, prog_bar=True)
self.log("train/kl_loss", kl_loss.item(), on_step=True, prog_bar=True)
self.log("train/lpips_loss", lpips_loss.item(), on_step=True, prog_bar=True)
self.log("train/psnr", self.train_psnr, on_step=True, prog_bar=True)
self.log("train/ssim", self.train_ssim, on_step=True, prog_bar=True)
self.log("train/g_loss", g_loss.item(), on_step=True, prog_bar=True)
self.log("train/d_weight", d_weight.item(), on_step=True, prog_bar=True)
        # optimize the discriminator: hinge loss on real images vs. detached
        # reconstructions, so this step does not update the autoencoder
logit_real = self.discriminator(img.contiguous())
logit_fake = self.discriminator(recon.detach().contiguous())
d_loss = hinge_d_loss(logit_real, logit_fake)
d_opt.zero_grad()
self.manual_backward(d_loss)
d_opt.step()
self.log("train/d_loss", d_loss.item(), on_step=True, prog_bar=True)
def validation_step(self, batch, batch_idx):
# Get the image
img = batch["images"]
        # Reconstruct the image deterministically from the latent mean
        recon = self.forward(img)
lpips_loss = self.lpips.forward(recon, img) # LPIPS loss
self.val_psnr(recon, img)
self.val_ssim(recon, img)
self.val_mse(recon, img)
self.val_lpips(lpips_loss)
# Log the loss
self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
self.log(
"val/lpips", self.val_lpips, on_epoch=True, on_step=False, prog_bar=True
)
if batch_idx == 0:
self.log_images(img, recon)
def compile(self):
for attr in [
"encoder",
"decoder",
"quant_conv",
"post_quant_conv",
"lpips",
"discriminator",
]:
            setattr(
                self, attr, torch.compile(getattr(self, attr), mode="max-autotune")
            )
def get_last_layer(self):
return self.decoder.conv_out.weight
    def calculate_adaptive_weight(self, rec_loss, g_loss):
        # Balance the adversarial term against the reconstruction term by
        # matching their gradient norms at the decoder's last layer
        # (the adaptive-weight trick used in VQGAN-style autoencoders)
        last_layer = self.get_last_layer()
        rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        # Additional fixed factor that down-weights the adversarial loss
        d_weight = d_weight * 0.5
        return d_weight
    def log_images(self, ori_images, recon_images):
        for ori_image, recon_image in zip(ori_images, recon_images):
            # Map from [-1, 1] to [0, 255] and clamp before the uint8 cast so
            # out-of-range values do not wrap around
            ori_image = ((ori_image + 1) * 127.5).clamp(0, 255).cpu().type(torch.uint8)
            recon_image = (
                ((recon_image + 1) * 127.5).clamp(0, 255).cpu().type(torch.uint8)
            )
wandb.log(
{
"val/ori_image": [
wandb.Image(ori_image),
],
"val/recon_image": [
wandb.Image(recon_image),
],
}
)
def configure_optimizers(self):
encoder_params = list(self.encoder.parameters())
decoder_params = list(self.decoder.parameters())
other_params = list(self.quant_conv.parameters()) + list(
self.post_quant_conv.parameters()
)
discriminator_params = list(self.discriminator.parameters())
        # Use `self.learning_rate` if a training script has set it (e.g. a
        # batch-size-scaled rate); otherwise fall back to the saved
        # `base_learning_rate` hyperparameter.
        lr = getattr(self, "learning_rate", self.hparams.base_learning_rate)
        ae_opt = torch.optim.AdamW(
            encoder_params + decoder_params + other_params,
            lr=lr,
            betas=(0.5, 0.9),
        )
        d_opt = torch.optim.AdamW(
            discriminator_params, lr=lr, betas=(0.5, 0.9)
        )
return [ae_opt, d_opt]
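# Minimal smoke-test sketch (an assumption, not part of the training pipeline):
# it presumes the constructor arguments below are compatible with this repo's
# `Encoder`/`Decoder` blocks, and constructing the model may download the
# pretrained VGG weights used by LPIPS. Run with
# `python -m swim.models.autoencoder` to check tensor shapes only.
if __name__ == "__main__":
    model = Autoencoder(
        channels=32,
        channel_multipliers=[1, 2, 4],
        n_resnet_blocks=1,
        in_channels=3,
        out_channels=3,
        z_channels=4,
        emb_channels=4,
        base_learning_rate=4.5e-6,
        kl_weight=1e-6,
    )
    x = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        recon = model(x)
    # For a symmetric encoder/decoder the reconstruction should match the input shape
    print(recon.shape)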