import torch
from torchinfo import summary

from diffusers import UNet2DModel

# Earlier experiment: comparing our Autoencoder against the pretrained
# Stable Diffusion VAE by dumping both state-dict key lists to JSON.
# from swim.autoencoder import Autoencoder
# from diffusers import AutoencoderKL
#
# vae = Autoencoder(
#     z_channels=4,
#     in_channels=3,
#     channels=128,
#     channel_multipliers=[1, 2, 4, 4],
#     n_resnet_blocks=2,
#     emb_channels=4,
# ).to("meta")
#
# lol_vae = AutoencoderKL.from_pretrained(
#     "stabilityai/stable-diffusion-2-1", subfolder="vae"
# ).to("meta")
#
# # copy weights from lol_vae to vae
# import json
#
# with open("lolvae.json", "w") as f:
#     json.dump(list(lol_vae.state_dict().keys()), f, indent=4)
#
# with open("vae.json", "w") as f:
#     json.dump(list(vae.state_dict().keys()), f, indent=4)
#
# sample = torch.randn(1, 3, 512, 512).to("meta")
# latent = vae.encoder(sample)

model = UNet2DModel(
    sample_size=512,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        128,
        128,
        256,
        256,
        512,
        512,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to("meta")  # meta device: no real weights allocated, enough for a shape summary

sample = torch.randn(1, 3, 512, 512).to("meta")
summary(
    model,
    input_data=(sample, 0),  # UNet2DModel.forward takes (sample, timestep)
)
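
# Optional sanity check (a minimal sketch, not part of the original script):
# rebuild the same UNet on CPU with real weights and run one denoising step
# to confirm the output shape matches the input. Assumes enough RAM for a
# 512x512 forward pass; shrink sample_size above if that is a problem.
if __name__ == "__main__":
    model_cpu = UNet2DModel.from_config(model.config)  # same architecture, real weights
    x = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        out = model_cpu(x, timestep=0).sample  # UNet2DOutput.sample holds the tensor
    assert out.shape == x.shape, out.shape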