import torch
from torchinfo import summary

from swim.autoencoder import Autoencoder
from diffusers import AutoencoderKL, UNet2DModel

# vae = Autoencoder(
#     z_channels=4,
#     in_channels=3,
#     channels=128,
#     channel_multipliers=[1, 2, 4, 4],
#     n_resnet_blocks=2,
#     emb_channels=4,
# ).to("meta")
# lol_vae = AutoencoderKL.from_pretrained(
#     "stabilityai/stable-diffusion-2-1", subfolder="vae"
# ).to("meta")

# # Compare state-dict keys to plan copying weights from lol_vae to vae.
# import json

# with open("lolvae.json", "w") as f:
#     json.dump(list(lol_vae.state_dict().keys()), f, indent=4)

# with open("vae.json", "w") as f:
#     json.dump(list(vae.state_dict().keys()), f, indent=4)
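
# # A minimal sketch of the copy itself, assuming the sorted key lists align
# # one-to-one (an assumption: the real key names differ between the two
# # models, so a hand-written rename map is likely needed):
# src_sd = lol_vae.state_dict()
# key_map = dict(zip(sorted(vae.state_dict()), sorted(src_sd)))
# vae.load_state_dict({dst: src_sd[src] for dst, src in key_map.items()})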

# sample = torch.randn(1, 3, 512, 512).to("meta")
# # latent = vae.encoder(sample)

model = UNet2DModel(
    sample_size=512,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        128,
        128,
        256,
        256,
        512,
        512,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to("meta")

# Dummy input on the meta device: shapes are traced, no real memory is allocated.
sample = torch.randn(1, 3, 512, 512).to("meta")

# Layer-by-layer summary via torchinfo; the 0 is the timestep argument to forward().
summary(
    model,
    input_data=(
        sample,
        0,
    ),
)
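
# The meta device lets torchinfo trace shapes without allocating weights, so
# the summarized model above cannot actually be run. A sketch for a real
# forward pass (assuming CPU; .sample holds the output tensor):
# real_model = UNet2DModel.from_config(model.config)
# out = real_model(torch.randn(1, 3, 512, 512), timestep=0).sample
# print(out.shape)  # expected: torch.Size([1, 3, 512, 512])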