import torch
from torchinfo import summary

from swim.autoencoder import Autoencoder
from diffusers import AutoencoderKL, UNet2DModel
# vae = Autoencoder(
#     z_channels=4,
#     in_channels=3,
#     channels=128,
#     channel_multipliers=[1, 2, 4, 4],
#     n_resnet_blocks=2,
#     emb_channels=4,
# ).to("meta")
# lol_vae = AutoencoderKL.from_pretrained(
#     "stabilityai/stable-diffusion-2-1", subfolder="vae"
# ).to("meta")
# # Dump both state-dict key lists to JSON to compare layouts before
# # copying weights from lol_vae to vae.
# import json
# with open("lolvae.json", "w") as f:
#     json.dump(list(lol_vae.state_dict().keys()), f, indent=4)
# with open("vae.json", "w") as f:
#     json.dump(list(vae.state_dict().keys()), f, indent=4)
# sample = torch.randn(1, 3, 512, 512).to("meta")
# # latent = vae.encoder(sample)
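# (Added sketch, not part of the original experiment: the two key lists could
# also be diffed directly in Python instead of comparing the JSON dumps by hand.)
# vae_keys = set(vae.state_dict().keys())
# lol_keys = set(lol_vae.state_dict().keys())
# print("only in swim VAE:", sorted(vae_keys - lol_keys))
# print("only in SD 2.1 VAE:", sorted(lol_keys - vae_keys))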
# 512x512 RGB-to-RGB UNet, instantiated on the meta device so that only the
# architecture is built and no real weight memory is allocated.
model = UNet2DModel(
    sample_size=512,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        128,
        128,
        256,
        256,
        512,
        512,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to("meta")
sample = torch.randn(1, 3, 512, 512, device="meta")
summary(
    model,
    input_data=(
        sample,
        0,  # dummy timestep for the UNet forward pass
    ),
)
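# Usage sketch (assumes enough host memory; not part of the original run): meta
# tensors hold no data, so to run a real forward pass the model has to be
# rebuilt on an actual device rather than moved off "meta".
# real_model = UNet2DModel.from_config(model.config).to("cpu")
# noisy = torch.randn(1, 3, 512, 512)
# pred = real_model(noisy, timestep=0).sample  # UNet2DOutput.sample
# print(pred.shape)  # torch.Size([1, 3, 512, 512])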