import torch
from torch import nn


# VAE decoder that produces a half-size output: the last upsample of the
# full decoder is omitted.
# Code adapted from madebyollin/taesd.
class Recon(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Main branch: three 3x3 convolutions with ReLU in between.
        self.long = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, 3, padding=1),
        )
        if ch_in != ch_out:
            # Project the shortcut to the matching channel count.
            self.short = nn.Conv2d(ch_in, ch_out, 1, bias=False)
        else:
            # Channel counts already match, so the shortcut is the identity.
            self.short = nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.long(x) + self.short(x))


class TeaDecoder(nn.Module):
    def __init__(self, ch_in):
        super().__init__()
        self.block_in = nn.Sequential(
            nn.Conv2d(ch_in, 64, 3, padding=1),
            nn.ReLU(),
        )
        self.middle = nn.Sequential(
            *[Recon(64, 64) for _ in range(3)],
            # Upsample by 2: the inverse of a stride-2 downsample.
            nn.Upsample(scale_factor=2),
            # bias=False keeps the model slightly smaller; the Recon blocks
            # that follow carry their own biases, so one here adds little.
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            *[Recon(64, 64) for _ in range(3)],
            # Final upsample, bringing the output to 1/2 the image size.
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
        )
        self.block_out = nn.Sequential(
            Recon(64, 64),
            # Convert to RGB, regardless of the latent channel count.
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, x):
        # Clamp the latent to (-1, 1) with tanh before decoding.
        clamped = torch.tanh(x)
        cooked = self.middle(self.block_in(clamped))
        return self.block_out(cooked)
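

# Minimal sanity-check sketch (an assumption, not part of the original code):
# a 4-channel latent at 1/8 of the image resolution, as in Stable Diffusion's
# VAE. Adjust ch_in and the latent shape if your autoencoder differs.
if __name__ == "__main__":
    decoder = TeaDecoder(ch_in=4)
    latent = torch.randn(1, 4, 32, 32)  # latent for a 256x256 image
    with torch.no_grad():
        image = decoder(latent)
    # Two 2x upsamples: 32 -> 64 -> 128, i.e. half of the 256x256 original.
    print(image.shape)  # torch.Size([1, 3, 128, 128])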