LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models
Abstract
Advances in latent diffusion models (LDMs) have revolutionized high-resolution image generation, but the design space of the autoencoder that is central to these systems remains underexplored. In this paper, we introduce LiteVAE, a family of autoencoders for LDMs that leverage the 2D discrete wavelet transform to enhance scalability and computational efficiency over standard variational autoencoders (VAEs) with no sacrifice in output quality. We also investigate the training methodologies and the decoder architecture of LiteVAE and propose several enhancements that improve the training dynamics and reconstruction quality. Our base LiteVAE model matches the quality of the established VAEs in current LDMs with a six-fold reduction in encoder parameters, leading to faster training and lower GPU memory requirements, while our larger model outperforms VAEs of comparable complexity across all evaluated metrics (rFID, LPIPS, PSNR, and SSIM).
Community
Plain English rewrite of the paper - feedback from the authors is much appreciated! https://www.aimodels.fyi/papers/arxiv/litevae-lightweight-efficient-variational-autoencoders-latent-diffusion
I really enjoyed reading your paper!
Your technique seems very logical and straightforward - I hope to use it on other encoding problems.
For me personally, these parts stood out:
I haven't used wavelet transforms before and was surprised at how easy they are to generate and how naturally they fit into the early stages of an encoder.
I hadn't seen the adaptive weighting scheme for lambda_adv mentioned on page 3. It seemed interesting, but it looks like you didn't end up using it (according to the appendix).
I was a little sad the decoder was the original SD decoder. I was so blown away by your small and efficient encoder that it was sad to see the big nasty SD decoder used. :-)
It was nice to see you doing pretraining at a lower res. This used to be in fashion, but you see less and less of it these days. I did find it interesting that you pretrained only at 1/2 res. I'm curious if you tried 1/4 res or if that was too destructive.
Your self-modulated convolution is very interesting. One thing that I don't think the paper made clear was where exactly in the decoder it is used. The paper says "using SMC in the decoder balances feature maps," but you didn't say which layers exactly you replaced/augmented with it. For example, did you do a 1:1 replacement of GroupNorm layers only? Or do your SMC layers replace both the decoder convolutions and the group norms that precede them? Is SMC also used on the residual skip connections when there is a channel count change? I would have liked to have seen more about this layer and how it's best used, but I guess that would have been a second paper.
I have never used Charbonnier loss before and can't wait to try it out.
Thanks for the paper! It's a great encoder and I learned a lot of new tricks. Keep up the good work.
-Frank
Dear Frank,
Thank you for your interest in our work, and I'm glad you liked the results. I'll address your questions below, and please feel free to contact us if you have additional comments.
Adaptive weight was introduced in the VQGAN paper, and some autoencoders use it during training. However, with our U-Net discriminator, we found it largely unnecessary and removed it for more stable training of the VAE. In our setup, it produced very large floating-point values that caused stability issues, especially in mixed-precision training.
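For context, the VQGAN adaptive weight rescales the adversarial loss by the ratio of the gradient norms of the reconstruction and adversarial losses with respect to the last decoder layer. A rough PyTorch sketch of that recipe (my paraphrase of the public VQGAN code, not our training code; last_layer stands for the weight of the final decoder convolution):

import torch

def adaptive_adv_weight(rec_loss, adv_loss, last_layer, delta=1e-4, max_w=1e4):
    # Gradients of each loss w.r.t. the last decoder layer weights (VQGAN-style).
    rec_grad = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
    adv_grad = torch.autograd.grad(adv_loss, last_layer, retain_graph=True)[0]
    # Ratio of gradient norms; the small denominator is what can blow up
    # under mixed precision.
    weight = rec_grad.norm() / (adv_grad.norm() + delta)
    return torch.clamp(weight, 0.0, max_w).detach()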
Please note that the decoder can be distilled or replaced independently of the encoder after training. Since it does not change the latent space, one can swap in whatever decoder network works best for the underlying problem. This is why we focused on the encoder: it cannot be changed afterwards, and it determines the efficiency of training in both VAEs and LDMs.
We have not tried pre-training at 1/4 resolution because the encoder has a downsampling factor of 8x, and we felt it might be too destructive. However, I'll be happy to try it and see whether it affects the final quality. Given that using 1/2 resolution worked very well for us, I assume that later fine-tuning at 1/2 and 1x resolution will address the compression issues at 1/4.
We used SMC whenever there was a combination of GN + Conv in the decoder. Thus, SMC replaces both GN and Conv (if I remember correctly, this even improved the efficiency of the decoder, because SMC was faster than GN + Conv). Please find pseudocode for this below. The skip connections use a normal convolution, as the original decoder also does not have any normalization in the skip-connection path. In our experiments, we found that SMC and GN can even be applied side by side (e.g., using GN + Conv for the low-resolution blocks and switching to SMC only in the higher-resolution blocks, since we noted that blob artifacts are more prominent at higher resolutions). We hope our observation leads to further investigation into the use of GN in decoder networks and the mitigation of its side effects.
# Original decoder block: GroupNorm followed by convolution
h = self.group_norm(x)
h = self.conv(h)
h = self.act(h)

# SMC block: one self-modulated convolution replaces GN + Conv
h = self.smc(x, w)
h = self.act(h)
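If it helps, here is a rough, self-contained PyTorch sketch of how such a layer could be written. This is an illustration in the spirit of StyleGAN2's weight modulation/demodulation, not our exact released code; it assumes a per-channel modulation vector shared across the batch (learned by default, or passed in as w):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SMCBlock(nn.Module):
    # Sketch of a self-modulated convolution: a per-channel vector modulates
    # the conv weights, and the weights are demodulated so activations stay
    # roughly unit-variance, removing the need for a separate GroupNorm.
    def __init__(self, in_ch, out_ch, k=3, eps=1e-8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_ch, in_ch, k, k) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_ch))
        self.mod = nn.Parameter(torch.ones(in_ch))  # learned modulation vector
        self.eps = eps

    def forward(self, x, w=None):
        mod = self.mod if w is None else w                      # per-channel modulation
        wgt = self.weight * mod.view(1, -1, 1, 1)               # modulate
        demod = torch.rsqrt(wgt.pow(2).sum(dim=(1, 2, 3)) + self.eps)
        wgt = wgt * demod.view(-1, 1, 1, 1)                     # demodulate
        return F.conv2d(x, wgt, self.bias, padding=self.weight.shape[-1] // 2)

Because the demodulation keeps the effective weights roughly unit-norm, the activations remain well scaled without an explicit normalization layer.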
Once again, thank you for your feedback and your interest in our work!
Best regards,
Morteza
Hello, in Figure 10 of the paper, the pseudocode for SMC is published, which involves the modulated_conv2d function. Is there a PyTorch implementation of this function available? Thank you!!
Hello. Thank you for your message. The modulated_conv2d function is imported from the StyleGAN implementation.
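If pulling in the StyleGAN code is inconvenient, a simplified stand-alone version may help. This is my own condensation of the StyleGAN2 idea (per-sample styles handled with the grouped-conv trick), without the fused kernels, resampling, and noise options of the original:

import torch
import torch.nn.functional as F

def modulated_conv2d(x, weight, styles, demodulate=True, eps=1e-8, padding=1):
    # x: (B, Cin, H, W), weight: (Cout, Cin, kh, kw), styles: (B, Cin)
    B, Cin, H, W = x.shape
    Cout, _, kh, kw = weight.shape
    # Per-sample modulation of the shared convolution weights.
    w = weight.unsqueeze(0) * styles.view(B, 1, Cin, 1, 1)
    if demodulate:
        d = torch.rsqrt(w.pow(2).sum(dim=(2, 3, 4), keepdim=True) + eps)
        w = w * d
    # Grouped-conv trick: fold the batch into the channel dimension.
    x = x.reshape(1, B * Cin, H, W)
    w = w.reshape(B * Cout, Cin, kh, kw)
    out = F.conv2d(x, w, padding=padding, groups=B)
    return out.reshape(B, Cout, out.shape[-2], out.shape[-1])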
Please let me know if you have additional questions.
Hi, this paper looks very promising, and I am interested in applying the model idea in my research work. Could you share some rough ideas about the feature-extraction and feature-aggregation blocks in the encoder part?
Thank you for your interest in our paper. The feature-extraction and feature-aggregation modules are based on UNets, derived from the ADM implementation (https://github.com/openai/guided-diffusion). The key difference is that we removed the time-step embedding input and the down/up-sampling blocks from the model. As a result, each network is essentially a series of residual blocks with identical input and output resolutions.
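As a rough sketch of that description (my own simplification, not the released code; the block counts and channel widths below are placeholders), each module is essentially:

import torch.nn as nn

class ResBlock(nn.Module):
    # Resolution-preserving residual block (ADM-style, without timestep embedding).
    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(32, ch), nn.SiLU(), nn.Conv2d(ch, ch, 3, padding=1),
            nn.GroupNorm(32, ch), nn.SiLU(), nn.Conv2d(ch, ch, 3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)

class FeatureModule(nn.Module):
    # Feature-extraction / aggregation module: a plain stack of residual blocks
    # with no down/up-sampling, so input and output resolutions match.
    def __init__(self, in_ch, hidden_ch=64, out_ch=None, num_blocks=4):
        super().__init__()
        out_ch = out_ch or hidden_ch
        self.proj_in = nn.Conv2d(in_ch, hidden_ch, 3, padding=1)
        self.blocks = nn.Sequential(*[ResBlock(hidden_ch) for _ in range(num_blocks)])
        self.proj_out = nn.Conv2d(hidden_ch, out_ch, 3, padding=1)

    def forward(self, x):
        return self.proj_out(self.blocks(self.proj_in(x)))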
If you have any further questions, please feel free to contact us.