qninhdt committed on
Commit
7c61758
1 Parent(s): df1eac6
Files changed (1) hide show
  1. swim/models/autoencoder.py +17 -3
swim/models/autoencoder.py CHANGED
@@ -10,7 +10,9 @@ from torchmetrics import (
10
  PeakSignalNoiseRatio,
11
  StructuralSimilarityIndexMeasure,
12
  MeanSquaredError,
 
13
  )
 
14
 
15
 
16
  class Autoencoder(LightningModule):
@@ -63,12 +65,15 @@ class Autoencoder(LightningModule):
63
  # embedding space
64
  self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
65
 
 
 
66
  self.train_psnr = PeakSignalNoiseRatio()
67
  self.train_ssim = StructuralSimilarityIndexMeasure()
68
 
69
  self.val_psnr = PeakSignalNoiseRatio()
70
  self.val_ssim = StructuralSimilarityIndexMeasure()
71
  self.val_mse = MeanSquaredError()
 
72
 
73
  def encode(self, img: torch.Tensor) -> GaussianDistribution:
74
  """
@@ -114,17 +119,20 @@ class Autoencoder(LightningModule):
114
  img = batch["images"]
115
  recon = self.forward(img)
116
  # Calculate the loss
117
- loss = torch.abs(img - recon).sum() # L1 loss
 
 
118
 
119
  self.train_psnr(recon, img)
120
  self.train_ssim(recon, img)
121
 
122
  # Log the loss
123
- self.log("train/l1_loss", loss.item(), on_step=True, prog_bar=True)
 
124
  self.log("train/psnr", self.train_psnr, on_step=True, prog_bar=True)
125
  self.log("train/ssim", self.train_ssim, on_step=True, prog_bar=True)
126
 
127
- return loss
128
 
129
  def validation_step(self, batch, batch_idx):
130
  """
@@ -138,14 +146,20 @@ class Autoencoder(LightningModule):
138
  # Get the distribution
139
  recon = self.forward(img)
140
 
 
 
141
  self.val_psnr(recon, img)
142
  self.val_ssim(recon, img)
143
  self.val_mse(recon, img)
 
144
 
145
  # Log the loss
146
  self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
147
  self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
148
  self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
 
 
 
149
 
150
  if batch_idx == 0:
151
  self.log_images(img, recon)
 
10
  PeakSignalNoiseRatio,
11
  StructuralSimilarityIndexMeasure,
12
  MeanSquaredError,
13
+ MeanMetric,
14
  )
15
+ from lpips import LPIPS
16
 
17
 
18
  class Autoencoder(LightningModule):
 
65
  # embedding space
66
  self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
67
 
68
+ self.lpips = LPIPS(net="vgg").eval()
69
+
70
  self.train_psnr = PeakSignalNoiseRatio()
71
  self.train_ssim = StructuralSimilarityIndexMeasure()
72
 
73
  self.val_psnr = PeakSignalNoiseRatio()
74
  self.val_ssim = StructuralSimilarityIndexMeasure()
75
  self.val_mse = MeanSquaredError()
76
+ self.val_lpips = MeanMetric()
77
 
78
  def encode(self, img: torch.Tensor) -> GaussianDistribution:
79
  """
 
119
  img = batch["images"]
120
  recon = self.forward(img)
121
  # Calculate the loss
122
+ l1_loss = torch.abs(img - recon).sum() # L1 loss
123
+ lpips_loss = self.lpips.forward(recon, img).sum() # LPIPS loss
124
+ total_loss = l1_loss + lpips_loss
125
 
126
  self.train_psnr(recon, img)
127
  self.train_ssim(recon, img)
128
 
129
  # Log the loss
130
+ self.log("train/l1_loss", l1_loss.item(), on_step=True, prog_bar=True)
131
+ self.log("train/lpips_loss", lpips_loss.item(), on_step=True, prog_bar=True)
132
  self.log("train/psnr", self.train_psnr, on_step=True, prog_bar=True)
133
  self.log("train/ssim", self.train_ssim, on_step=True, prog_bar=True)
134
 
135
+ return total_loss
136
 
137
  def validation_step(self, batch, batch_idx):
138
  """
 
146
  # Get the distribution
147
  recon = self.forward(img)
148
 
149
+ lpips_loss = self.lpips.forward(recon, img) # LPIPS loss
150
+
151
  self.val_psnr(recon, img)
152
  self.val_ssim(recon, img)
153
  self.val_mse(recon, img)
154
+ self.val_lpips(lpips_loss)
155
 
156
  # Log the loss
157
  self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
158
  self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
159
  self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
160
+ self.log(
161
+ "val/lpips", self.val_lpips, on_epoch=True, on_step=False, prog_bar=True
162
+ )
163
 
164
  if batch_idx == 0:
165
  self.log_images(img, recon)