# swim/train.py
import torch
from torchinfo import summary
from diffusers import UNet2DModel

# Imports below are only needed by the commented-out VAE scratch block:
# from swim.autoencoder import Autoencoder
# from diffusers import AutoencoderKL
# Scratch code (kept commented out) that instantiates the custom swim
# Autoencoder next to the Stable Diffusion 2.1 VAE and dumps both
# state-dict key lists to JSON, as groundwork for copying weights across:
# vae = Autoencoder(
#     z_channels=4,
#     in_channels=3,
#     channels=128,
#     channel_multipliers=[1, 2, 4, 4],
#     n_resnet_blocks=2,
#     emb_channels=4,
# ).to("meta")
# lol_vae = AutoencoderKL.from_pretrained(
#     "stabilityai/stable-diffusion-2-1", subfolder="vae"
# ).to("meta")
# # copy weights from lol_vae to vae
# import json
# with open("lolvae.json", "w") as f:
#     json.dump(list(lol_vae.state_dict().keys()), f, indent=4)
# with open("vae.json", "w") as f:
#     json.dump(list(vae.state_dict().keys()), f, indent=4)
# sample = torch.randn(1, 3, 512, 512).to("meta")
# # latent = vae.encoder(sample)
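# A hedged sketch of the "copy weights" step the block above alludes to:
# after lining up the two JSON key dumps into a name mapping (the
# `key_mapping` dict here is hypothetical, built by hand from the dumps),
# tensors could be copied entry by entry. Note this needs real weights,
# not modules on the "meta" device.
#
# key_mapping = {...}  # {vae_key: lol_vae_key}
# dst, src = vae.state_dict(), lol_vae.state_dict()
# with torch.no_grad():
#     for vae_key, lol_key in key_mapping.items():
#         dst[vae_key].copy_(src[lol_key])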
model = UNet2DModel(
    sample_size=512,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        128,
        128,
        256,
        256,
        512,
        512,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to("meta")
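# The "meta" device materializes the module structure without allocating
# real weight tensors, so the summary below can report layer shapes and
# parameter counts without touching GPU/CPU memory.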
sample = torch.randn(1, 3, 512, 512).to("meta")

summary(
    model,
    input_data=(
        sample,
        0,  # timestep argument required by UNet2DModel.forward
    ),
)
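
# A minimal sketch of how this UNet could be trained with a standard DDPM
# noise-prediction objective. Everything below (the DDPMScheduler settings,
# the AdamW hyperparameters, and a `dataloader` yielding batches of
# 3x512x512 images in [-1, 1]) is an assumption for illustration, not part
# of this script; the model would also have to be instantiated on a real
# device instead of "meta" before training.
#
# from diffusers import DDPMScheduler
# import torch.nn.functional as F
#
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#
# for clean_images in dataloader:  # hypothetical dataloader
#     noise = torch.randn_like(clean_images)
#     timesteps = torch.randint(
#         0,
#         noise_scheduler.config.num_train_timesteps,
#         (clean_images.shape[0],),
#         device=clean_images.device,
#     )
#     # Forward-diffuse the clean images to a random timestep.
#     noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
#     # The UNet predicts the noise that was added.
#     noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
#     loss = F.mse_loss(noise_pred, noise)
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()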