File size: 7,825 Bytes
8cc0674
 
 
 
 
 
 
 
 
 
7c61758
8cc0674
0846051
7c61758
8cc0674
567d0f7
8cc0674
 
567d0f7
8cc0674
 
 
 
 
 
 
 
 
 
0846051
d4e3955
dcbfb4d
8cc0674
 
 
 
 
567d0f7
 
d4e3955
 
8cc0674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0846051
 
8cc0674
 
 
 
0846051
7c61758
567d0f7
 
 
 
8cc0674
 
 
7c61758
8cc0674
 
 
 
0846051
8cc0674
0846051
8cc0674
 
 
 
 
 
 
0846051
567d0f7
8cc0674
0846051
 
 
 
 
d4e3955
8cc0674
 
 
567d0f7
 
dcbfb4d
 
567d0f7
8cc0674
 
 
d4e3955
 
 
 
8cc0674
0846051
 
 
d4e3955
 
dcbfb4d
 
 
567d0f7
dcbfb4d
 
 
 
567d0f7
 
 
 
 
 
8cc0674
 
7c61758
d4e3955
7c61758
567d0f7
 
8cc0674
567d0f7
dcbfb4d
 
 
8cc0674
dcbfb4d
567d0f7
dcbfb4d
 
 
 
 
567d0f7
 
8cc0674
567d0f7
8cc0674
 
 
 
 
7c61758
 
8cc0674
 
 
7c61758
8cc0674
 
 
 
 
7c61758
 
 
8cc0674
dcbfb4d
 
 
8cc0674
95aa666
567d0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cc0674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567d0f7
8cc0674
567d0f7
8cc0674
0846051
567d0f7
 
 
 
 
8cc0674
567d0f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from typing import List

import torch
import torch.nn as nn
from lightning import LightningModule

import wandb
from .blocks import Encoder, Decoder, GaussianDistribution
from torchmetrics import (
    MeanSquaredError,
    MeanMetric,
)
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from lpips import LPIPS

from swim.models.discriminators import NLayerDiscriminator, weights_init, hinge_d_loss


class Autoencoder(LightningModule):
    """KL-regularized convolutional autoencoder trained with an adversarial loss.

    The encoder output is projected to the moments of a ``GaussianDistribution``
    over latents. The decoder reconstructs images from a latent sample during
    training, or from the distribution mean in ``forward`` by default.

    Training uses manual optimization with two optimizers:

    * one for the autoencoder (encoder, decoder, ``quant_conv``,
      ``post_quant_conv``);
    * one for ``NLayerDiscriminator``, which only takes part once
      ``global_step >= gan_start``.

    Images are read from ``batch["images"]``. ``log_images`` maps pixels with
    ``(x + 1) * 127.5``, so images are assumed to lie in ``[-1, 1]``.
    """

    def __init__(
        self,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        out_channels: int,
        z_channels: int,
        emb_channels: int,
        base_learning_rate: float,
        kl_weight: float,
        gan_start: int,
    ):
        """Build the encoder/decoder, discriminator, perceptual loss and metrics.

        Args:
            channels: Base channel count passed to ``Encoder``/``Decoder``.
            channel_multipliers: Per-resolution channel multipliers.
            n_resnet_blocks: Number of resnet blocks per resolution.
            in_channels: Channels of the input images.
            out_channels: Channels of the reconstructed images; also the
                discriminator's input channel count.
            z_channels: Channels produced by the encoder / consumed by the decoder.
            emb_channels: Channels of the latent the distribution is defined over.
            base_learning_rate: Learning rate used by ``configure_optimizers``
                when no ``learning_rate`` attribute has been set on the module.
            kl_weight: Weight of the KL term in the autoencoder loss.
            gan_start: Global step from which the adversarial losses are used.
        """
        super().__init__()

        # Stores all constructor args under `self.hparams` (used for
        # `gan_start` and `base_learning_rate` below).
        self.save_hyperparameters(logger=False)

        # Two optimizers are stepped explicitly in `training_step`.
        self.automatic_optimization = False

        self.kl_weight = kl_weight

        self.encoder = Encoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
        )

        self.decoder = Decoder(
            channels=channels,
            channel_multipliers=channel_multipliers,
            n_resnet_blocks=n_resnet_blocks,
            out_channels=out_channels,
            z_channels=z_channels,
        )

        # Maps the encoder output (2 * z_channels) to distribution moments
        # (2 * emb_channels); presumably mean / log-variance halves -- confirm
        # against GaussianDistribution.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # Maps a latent (emb_channels) back to the decoder's z_channels.
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

        # Perceptual metric (VGG backbone): used in the training loss and for
        # the val/lpips metric.
        self.lpips = LPIPS(net="vgg")

        # The discriminator judges reconstructions and real images, so its
        # input channel count follows `out_channels` (3 for RGB).
        self.discriminator = NLayerDiscriminator(
            input_nc=out_channels, n_layers=3, use_actnorm=False
        ).apply(weights_init)

        self.val_psnr = PeakSignalNoiseRatio()
        self.val_ssim = StructuralSimilarityIndexMeasure()
        self.val_mse = MeanSquaredError()
        self.val_lpips = MeanMetric()

    def encode(self, img: torch.Tensor) -> GaussianDistribution:
        """Encode ``img`` and return the latent ``GaussianDistribution``."""
        z = self.encoder(img)
        moments = self.quant_conv(z)
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor):
        """Decode a latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, img: torch.Tensor, sample: bool = False):
        """Reconstruct ``img``.

        Args:
            img: Input image batch.
            sample: If True decode a random latent sample, otherwise decode the
                distribution mean (deterministic).
        """
        z_dis = self.encode(img)
        z = z_dis.sample() if sample else z_dis.mean
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        """Update the autoencoder, then (once enabled) the discriminator.

        Args:
            batch: Mapping whose ``"images"`` entry holds the image batch.
            batch_idx: Index of the batch (unused).
        """
        ae_opt, d_opt = self.optimizers()

        # Adversarial terms are switched on only from `gan_start` onwards.
        opt_gan = self.global_step >= self.hparams.gan_start

        # ---- Autoencoder update ----
        img = batch["images"]
        z_dis = self.encode(img)
        z = z_dis.sample()
        recon = self.decode(z)

        B = img.shape[0]
        # L1 and LPIPS are summed per sample and averaged over the batch.
        l1_loss = torch.abs(img - recon).sum() / B
        lpips_loss = self.lpips(recon, img).sum() / B
        kl_loss = z_dis.kl().mean()

        if opt_gan:
            logit_fake = self.discriminator(recon.contiguous())
            g_loss = -logit_fake.mean()
            # Scale the generator loss relative to L1 via gradient norms at the
            # decoder's last layer.
            d_weight = self.calculate_adaptive_weight(l1_loss, g_loss)
        else:
            # Placeholders so the logged keys exist before the GAN phase; with
            # d_weight == 0 the generator term contributes nothing.
            g_loss = torch.tensor(0.0, device=self.device)
            d_weight = torch.tensor(0.0, device=self.device)

        total_loss = l1_loss + lpips_loss + self.kl_weight * kl_loss + d_weight * g_loss

        ae_opt.zero_grad()
        self.manual_backward(total_loss)
        ae_opt.step()

        self.log("train/l1_loss", l1_loss.item(), on_step=True, prog_bar=True)
        self.log("train/kl_loss", kl_loss.item(), on_step=True, prog_bar=True)
        self.log("train/lpips_loss", lpips_loss.item(), on_step=True, prog_bar=True)
        self.log("train/g_loss", g_loss.item(), on_step=True, prog_bar=True)
        self.log("train/d_weight", d_weight.item(), on_step=True, prog_bar=True)

        # ---- Discriminator update ----
        if opt_gan:
            logit_real = self.discriminator(img.contiguous())
            # Detach so discriminator gradients do not reach the autoencoder.
            logit_fake = self.discriminator(recon.detach().contiguous())

            d_loss = hinge_d_loss(logit_real, logit_fake)

            # Also clears discriminator grads accumulated by the AE backward.
            d_opt.zero_grad()
            self.manual_backward(d_loss)
            d_opt.step()
        else:
            d_loss = torch.tensor(0.0, device=self.device)

        self.log("train/d_loss", d_loss.item(), on_step=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Compute reconstruction metrics and log sample images on batch 0.

        Args:
            batch: Mapping whose ``"images"`` entry holds the image batch.
            batch_idx: Index of the batch; images are logged only for 0.
        """
        img = batch["images"]
        # Deterministic reconstruction from the latent mean.
        recon = self.forward(img)

        lpips_loss = self.lpips(recon, img)

        self.val_psnr(recon, img)
        self.val_ssim(recon, img)
        self.val_mse(recon, img)
        self.val_lpips(lpips_loss)

        self.log("val/psnr", self.val_psnr, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/ssim", self.val_ssim, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val/mse", self.val_mse, on_epoch=True, on_step=False, prog_bar=True)
        self.log(
            "val/lpips", self.val_lpips, on_epoch=True, on_step=False, prog_bar=True
        )

        if batch_idx == 0 and self.trainer.is_global_zero:
            num_imgs = min(img.shape[0], 16)
            self.log_images(img[:num_imgs], recon[:num_imgs])

    def compile(self):
        """Replace the main submodules with ``torch.compile``d versions.

        ``torch.compile`` takes every option after the module as keyword-only,
        so ``mode`` must be passed by name.
        """
        for attr in (
            "encoder",
            "decoder",
            "quant_conv",
            "post_quant_conv",
            "lpips",
            "discriminator",
        ):
            setattr(self, attr, torch.compile(getattr(self, attr), mode="max-autotune"))

    def get_last_layer(self):
        """Return the weight of the decoder's final convolution."""
        return self.decoder.conv_out.weight

    def calculate_adaptive_weight(self, rec_loss, g_loss):
        """Weight for ``g_loss`` balancing it against ``rec_loss``.

        Computed as ``0.5 * ||grad rec_loss|| / (||grad g_loss|| + 1e-4)``, where
        both gradients are taken w.r.t. the decoder's last layer. The result is
        clamped to ``[0, 1e4]`` before scaling and detached, so it acts as a
        constant.
        """
        last_layer = self.get_last_layer()

        # retain_graph: the same graph is back-propagated again for the total loss.
        rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * 0.5

    def log_images(self, ori_images, recon_images):
        """Log paired original / reconstructed images to Weights & Biases.

        Pixels are mapped from ``[-1, 1]`` to ``[0, 255]`` and clamped before
        the ``uint8`` cast. Decoder outputs are unbounded, and casting
        out-of-range floats to ``uint8`` does not saturate.

        All pairs go into one ``wandb.log`` call, so they share a single step
        instead of advancing the wandb step once per image.
        """

        def to_uint8(x):
            return ((x.detach() + 1) * 127.5).clamp(0, 255).cpu().type(torch.uint8)

        ori_logged, recon_logged = [], []
        for ori_image, recon_image in zip(ori_images, recon_images):
            ori_logged.append(wandb.Image(to_uint8(ori_image)))
            recon_logged.append(wandb.Image(to_uint8(recon_image)))

        if not ori_logged:
            return

        wandb.log({"val/ori_image": ori_logged, "val/recon_image": recon_logged})

    def configure_optimizers(self):
        """Create the autoencoder and discriminator AdamW optimizers.

        The learning rate is ``self.learning_rate`` if something outside this
        class has assigned it. Otherwise it is the ``base_learning_rate``
        hyperparameter; this class never sets ``learning_rate`` itself.

        Returns:
            ``[ae_opt, d_opt]``, in the order ``training_step`` unpacks them.
        """
        lr = getattr(self, "learning_rate", self.hparams.base_learning_rate)

        ae_params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )

        ae_opt = torch.optim.AdamW(ae_params, lr=lr, betas=(0.5, 0.9))
        d_opt = torch.optim.AdamW(
            self.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )

        return [ae_opt, d_opt]