cloneofsimo committed
Commit c4ff943
1 Parent(s): 084a8fa

Update README.md

Files changed (1): README.md (+88 -3)

---
license: apache-2.0
---

# Equivariant 16ch, f8 VAE

<video controls autoplay src="https://cdn-uploads.huggingface.co/production/uploads/6311151c64939fabc00c8436/6DQGRWvQvDXp2xQlvwvwU.mp4"></video>

AuraEquiVAE is a novel autoencoder that fixes multiple problems of conventional VAEs. First, unlike a traditional VAE, whose learned log-variance is typically vanishingly small, this model admits large noise in the latent.
Second, unlike a traditional VAE, its latent space is equivariant under the `Z_2 x Z_2` group action (horizontal / vertical flips).

To understand the equivariance, we equip the latent with a suitable group action both globally and locally. That is, writing the latent as `Z = (z_1, \cdots, z_n)`, we apply a permutation action `g_global` to the tuple, where `g_global` is isomorphic to the `Z_2 x Z_2` group, and an action `g_local` to each individual `z_i`, where `g_local` is also isomorphic to `Z_2 x Z_2`.
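
Formally, the property we are after is that decoding commutes with the group action (our notation, not the training repo's: `D` is the decoder and `g` a flip acting on images):

```latex
% Equivariance of the decoder D under a flip g in Z_2 x Z_2:
% acting on the latent and then decoding equals
% decoding first and then flipping in pixel space.
D\!\left( g_{\mathrm{local}} \, g_{\mathrm{global}} \cdot Z \right) = g \cdot D(Z),
\qquad g \in Z_2 \times Z_2 .
```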

In our case specifically, `g_global` corresponds to spatial flips and `g_local` corresponds to sign flips on specific latent channels. Reserving two channels for the sign flip for each of the horizontal and vertical flips was chosen empirically; a minimal sketch of this action follows.
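
As a self-contained sketch of this `Z_2 x Z_2` action (the helper names `act_h` / `act_v` are ours; the channel layout matches the usage example below, with the last four latent channels reserved for the two sign flips):

```python
import torch

def act_h(z: torch.Tensor) -> torch.Tensor:
    """Horizontal-flip element: g_global flips the latent along width,
    g_local negates channels [-4:-2]."""
    z = torch.flip(z, [-1])      # g_global: permute spatial positions
    z[:, -4:-2] = -z[:, -4:-2]   # g_local: sign flip on two channels
    return z

def act_v(z: torch.Tensor) -> torch.Tensor:
    """Vertical-flip element: g_global flips the latent along height,
    g_local negates channels [-2:]."""
    z = torch.flip(z, [-2])
    z[:, -2:] = -z[:, -2:]
    return z

# Both elements are involutions and they commute, so
# {id, act_h, act_v, act_h . act_v} is isomorphic to Z_2 x Z_2:
z = torch.randn(1, 16, 32, 32)
assert torch.equal(act_h(act_h(z)), z)
assert torch.equal(act_v(act_v(z)), z)
assert torch.equal(act_h(act_v(z)), act_v(act_h(z)))
```

The usage example below applies exactly this action to a real latent.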

The model was trained with [Mastering VAE Training](https://github.com/cloneofsimo/vqgan-training); a detailed explanation of the training procedure can be found there.

## How to use

To use the weights, copy-paste the [VAE](https://github.com/cloneofsimo/vqgan-training/blob/03e04401cf49fe55be612d1f568be0110aa0fad1/ae.py) implementation.

```python
from ae import VAE  # VAE implementation from the vqgan-training repo linked above
import torch
from PIL import Image
from torchvision import transforms  # missing in the original snippet; needed for ToTensor
from safetensors.torch import load_file

vae = VAE(
    resolution=256,
    in_channels=3,
    ch=256,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,  # truncated to `z_ch` in the original; 16 matches the "16ch" model name
).cuda().bfloat16()

state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt")
vae.load_state_dict(state_dict)

imgpath = 'contents/lavender.jpg'

# Load a 768x768 crop and normalize to [-1, 1]
img_orig = Image.open(imgpath).convert("RGB")
offset = 128
W = 768
img_orig = img_orig.crop((offset, offset, W + offset, W + offset))
img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda().bfloat16()  # match the model's dtype
img = (img - 0.5) / 0.5

with torch.no_grad():
    z = vae.encoder(img)
    z = z.clamp(-8.0, 8.0)  # this is the latent!

# flip horizontally
z = torch.flip(z, [-1])     # this corresponds to g_global
z[:, -4:-2] = -z[:, -4:-2]  # this corresponds to g_local

# flip vertically
z = torch.flip(z, [-2])
z[:, -2:] = -z[:, -2:]

with torch.no_grad():
    decz = vae.decoder(z)  # this is the image!

decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
decimg = (decimg * 255).astype('uint8')
decimg = Image.fromarray(decimg)  # PIL image
```
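
Two quick sanity checks, as a hedged sketch reusing `vae` and `img` from the snippet above: the first probes the claim that the latent tolerates large noise, the second checks the flip equivariance end to end. The noise scale `1.0` is an illustrative assumption, not a value from the training repo.

```python
with torch.no_grad():
    z0 = vae.encoder(img).clamp(-8.0, 8.0)

    # 1) Noise robustness: perturb the latent and decode; a noise-tolerant
    #    latent space should still give a faithful reconstruction.
    rec_noisy = vae.decoder(z0 + 1.0 * torch.randn_like(z0))

    # 2) Equivariance: decoding the acted-on latent should match
    #    flipping the plain reconstruction in pixel space.
    rec = vae.decoder(z0)
    rec_flipped = torch.flip(rec, [-1])  # g acting on the output image

    zh = torch.flip(z0, [-1])            # g_global on the latent
    zh[:, -4:-2] = -zh[:, -4:-2]         # g_local on the latent
    rec_equiv = vae.decoder(zh)

print((rec_flipped - rec_equiv).abs().mean())  # small if equivariance holds
```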

## Citation

If you find this material useful, please cite:

```bibtex
@misc{ryu2024vqgantraining,
  author = {Simo Ryu},
  title = {Training VQGAN and VAE, with detailed explanation},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
}
```