cc
- scripts/build_cyclegan_dataset.py +47 -19
- swim/__init__.py +0 -0
- swim/attention_blocks.py +315 -0
- swim/autoencoder.py +0 -0
- swim/blocks.py +185 -0
- swim/codeblock.py +74 -0
- swim/discriminator.py +45 -0
- swim/encoder.py +90 -0
- swim/unet.py +169 -0
- train.py +8 -0
scripts/build_cyclegan_dataset.py
CHANGED
@@ -8,7 +8,8 @@ from tqdm import tqdm
 @click.option("--swim_dir", type=str, default="datasets/swim_data")
 @click.option("--output_dir", type=str, default="datasets/swim_data_cyclegan")
 @click.option("--type", type=str, help="fog|rain|snow|night", required=True)
-def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str):
+@click.option("--no_night", is_flag=True)
+def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str, no_night: bool):
     # build the dataset with format
     # swim_data_cyclegan
     # ├── trainA

@@ -42,25 +43,52 @@ def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str):
     with open(os.path.join(swim_dir, "val", "labels.json"), "r") as f:
         val_labels = json.load(f)

-            )
-        elif label["weather"] == "clear":
-            os.system(
-                f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
-            )
+    if type != "night":
+        for label in tqdm(train_labels, desc="train"):
+            if no_night and label["timeofdata"] == "night":
+                continue
+
+            if label["weather"] == type:
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
+                )
+            elif label["weather"] == "clear":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
+                )
+
+        for label in tqdm(val_labels, desc="val"):
+            if no_night and label["timeofdata"] == "night":
+                continue
+
+            if label["weather"] == type:
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
+                )
+            elif label["weather"] == "clear":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testA', label['name'])}"
+                )
+    else:
+        for label in tqdm(train_labels, desc="train"):
+            if label["timeofdata"] == "night":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
+                )
+            elif label["timeofdata"] == "daytime":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
+                )
+
+        for label in tqdm(val_labels, desc="val"):
+            if label["timeofdata"] == "night":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
+                )
+            elif label["timeofdata"] == "daytime":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testA', label['name'])}"
+                )


 if __name__ == "__main__":
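Not part of the commit: the branching added above can be read as a single selection rule per image. The sketch below is an illustrative restatement of that rule; the assign_split helper and its arguments are hypothetical, and it assumes each label dict carries "weather" and "timeofdata" keys as in the diff.

# Illustrative only -- a hypothetical helper mirroring the split logic added above.
def assign_split(label: dict, type: str, split: str, no_night: bool):
    prefix = "train" if split == "train" else "test"
    if type != "night":
        # weather-based pairing: clear -> domain A, the requested weather -> domain B
        if no_night and label["timeofdata"] == "night":
            return None  # --no_night drops night images entirely
        if label["weather"] == type:
            return prefix + "B"
        if label["weather"] == "clear":
            return prefix + "A"
    else:
        # day/night pairing: daytime -> domain A, night -> domain B
        if label["timeofdata"] == "night":
            return prefix + "B"
        if label["timeofdata"] == "daytime":
            return prefix + "A"
    return None  # anything else is skipped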
swim/__init__.py
ADDED
File without changes
swim/attention_blocks.py
ADDED
@@ -0,0 +1,315 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class SpatialTransformer(nn.Module):
    """
    ## Spatial Transformer
    """

    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
        """
        :param channels: is the number of channels in the feature map
        :param n_heads: is the number of attention heads
        :param n_layers: is the number of transformer layers
        :param d_cond: is the size of the conditional embedding
        """
        super().__init__()
        # Initial group normalization
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=channels, eps=1e-6, affine=True
        )
        # Initial $1 \times 1$ convolution
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

        # Transformer layers
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    channels, n_heads, channels // n_heads, d_cond=d_cond
                )
                for _ in range(n_layers)
            ]
        )

        # Final $1 \times 1$ convolution
        self.proj_out = nn.Conv2d(
            channels, channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: is the feature map of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        # Get shape `[batch_size, channels, height, width]`
        b, c, h, w = x.shape
        # For residual connection
        x_in = x
        # Normalize
        x = self.norm(x)
        # Initial $1 \times 1$ convolution
        x = self.proj_in(x)
        # Transpose and reshape from `[batch_size, channels, height, width]`
        # to `[batch_size, height * width, channels]`
        x = x.permute(0, 2, 3, 1).view(b, h * w, c)
        # Apply the transformer layers
        for block in self.transformer_blocks:
            x = block(x, cond)
        # Reshape and transpose from `[batch_size, height * width, channels]`
        # to `[batch_size, channels, height, width]`
        x = x.view(b, h, w, c).permute(0, 3, 1, 2)
        # Final $1 \times 1$ convolution
        x = self.proj_out(x)
        # Add residual
        return x + x_in


class BasicTransformerBlock(nn.Module):
    """
    ### Transformer Layer
    """

    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param d_cond: is the size of the conditional embeddings
        """
        super().__init__()
        # Self-attention layer and pre-norm layer
        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
        self.norm1 = nn.LayerNorm(d_model)
        # Cross-attention layer and pre-norm layer
        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
        self.norm2 = nn.LayerNorm(d_model)
        # Feed-forward network and pre-norm layer
        self.ff = FeedForward(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        # Self-attention
        x = self.attn1(self.norm1(x)) + x
        # Cross-attention with conditioning
        x = self.attn2(self.norm2(x), cond=cond) + x
        # Feed-forward network
        x = self.ff(self.norm3(x)) + x
        #
        return x


class CrossAttention(nn.Module):
    """
    ### Cross Attention Layer

    This falls back to self-attention when conditional embeddings are not specified.
    """

    use_flash_attention: bool = False

    def __init__(
        self,
        d_model: int,
        d_cond: int,
        n_heads: int,
        d_head: int,
        is_inplace: bool = True,
    ):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param d_cond: is the size of the conditional embeddings
        :param is_inplace: specifies whether to perform the attention softmax computation inplace
            to save memory
        """
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head

        # Attention scaling factor
        self.scale = d_head**-0.5

        # Query, key and value mappings
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

        # Final linear layer
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

        # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
        # Flash attention is only used if it's installed
        # and `CrossAttention.use_flash_attention` is set to `True`.
        # try:
        #     # You can install flash attention by cloning their Github repo,
        #     # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
        #     # and then running `python setup.py install`
        #     from flash_attn.flash_attention import FlashAttention
        #
        #     self.flash = FlashAttention()
        #     # Set the scale for scaled dot-product attention.
        #     self.flash.softmax_scale = self.scale
        # # Set to `None` if it's not installed
        # except ImportError:
        #     self.flash = None

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """

        # If `cond` is `None` we perform self-attention
        has_cond = cond is not None
        if not has_cond:
            cond = x

        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

        # Use flash attention if it's available and the head size is less than or equal to `128`
        if (
            CrossAttention.use_flash_attention
            and self.flash is not None
            and not has_cond
            and self.d_head <= 128
        ):
            return self.flash_attention(q, k, v)
        # Otherwise, fall back to normal attention
        else:
            return self.normal_attention(q, k, v)

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Flash Attention

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        """

        # Get batch size and number of elements along sequence axis (`width * height`)
        batch_size, seq_len, _ = q.shape

        # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
        # shape `[batch_size, seq_len, 3, n_heads * d_head]`
        qkv = torch.stack((q, k, v), dim=2)
        # Split the heads
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

        # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
        # fit this size.
        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")

        # Pad the heads
        if pad:
            qkv = torch.cat(
                (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
            )

        # Compute attention
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
        out, _ = self.flash(qkv)
        # Truncate the extra head size
        out = out[:, :, :, : self.d_head]
        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Normal Attention

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        """

        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
        attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale

        # Compute softmax
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
        if self.is_inplace:
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        # Compute attention output
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        out = torch.einsum("bhij,bjhd->bihd", attn, v)
        # Reshape to `[batch_size, height * width, n_heads * d_head]`
        out = out.reshape(*out.shape[:2], -1)
        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)


class FeedForward(nn.Module):
    """
    ### Feed-Forward Network
    """

    def __init__(self, d_model: int, d_mult: int = 4):
        """
        :param d_model: is the input embedding size
        :param d_mult: is the multiplicative factor for the hidden layer size
        """
        super().__init__()
        self.net = nn.Sequential(
            GeGLU(d_model, d_model * d_mult),
            nn.Dropout(0.0),
            nn.Linear(d_model * d_mult, d_model),
        )

    def forward(self, x: torch.Tensor):
        return self.net(x)


class GeGLU(nn.Module):
    """
    ### GeGLU Activation

    $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
    """

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        # Combined linear projections $xW + b$ and $xV + c$
        self.proj = nn.Linear(d_in, d_out * 2)

    def forward(self, x: torch.Tensor):
        # Get $xW + b$ and $xV + c$
        x, gate = self.proj(x).chunk(2, dim=-1)
        # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
        return x * F.gelu(gate)
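Not part of the commit: a minimal shape check for the module above, assuming the package is importable as swim.attention_blocks. The sizes (64 channels, 4 heads, a 77x768 dummy conditioning sequence) are illustrative choices, not values fixed by the code.

# Illustrative only: run SpatialTransformer on a dummy feature map.
import torch
from swim.attention_blocks import SpatialTransformer

st = SpatialTransformer(channels=64, n_heads=4, n_layers=1, d_cond=768)
x = torch.randn(2, 64, 16, 16)   # [batch_size, channels, height, width]
cond = torch.randn(2, 77, 768)   # [batch_size, n_cond, d_cond]
out = st(x, cond)
print(out.shape)                 # residual output keeps the input shape: [2, 64, 16, 16]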
swim/autoencoder.py
ADDED
File without changes
swim/blocks.py
ADDED
@@ -0,0 +1,185 @@
from abc import abstractmethod

import math
import torch
from torch import nn
from torch.nn import functional as F


def get_timestep_embedding(
    timesteps: torch.Tensor, emb_dim: int, max_period: int = 10000
) -> torch.Tensor:
    half_dim = emb_dim // 2

    emb = math.log(max_period) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)

    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

    if emb_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))

    return emb


class GroupNorm(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()

        self.group_norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-06, affine=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.group_norm(x)


class UpsampleBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)


class DownsampleBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor):
        return self.op(x)


class TimestepBlock(nn.Module):
    @abstractmethod
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        pass


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, t_emb)
            else:
                x = layer(x)
        return x


class ResnetBlock(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        t_emb_dim: int = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        self.input_layers = nn.Sequential(
            GroupNorm(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )

        if t_emb_dim is not None:
            self.t_emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            )
        else:
            self.t_emb_layers = None

        self.output_layers = nn.Sequential(
            GroupNorm(out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor = None) -> torch.Tensor:
        assert t is not None or self.t_emb_layers is None

        h = self.input_layers(x)

        if self.t_emb_layers is not None:
            t_emb = self.t_emb_layers(t)
            h = h + t_emb[:, :, None, None]

        h = self.output_layers(h)

        h = h + self.skip(x)

        return h


class AttentionBlock(nn.Module):
    """Attention mechanism similar to transformers but for CNNs, paper https://arxiv.org/abs/1805.08318

    Args:
        in_channels (int): Number of channels in the input tensor.
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()

        self.in_channels = in_channels

        # normalization layer
        self.norm = GroupNorm(in_channels)

        # query, key and value layers
        self.q = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
        self.k = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
        self.v = nn.Conv2d(in_channels, in_channels, 1, 1, 0)

        self.project_out = nn.Conv2d(in_channels, in_channels, 1, 1, 0)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):

        batch, _, height, width = x.size()

        x = self.norm(x)

        # query, key and value layers
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # resizing the output from 4D to 3D to generate attention map
        q = q.reshape(batch, self.in_channels, height * width)
        k = k.reshape(batch, self.in_channels, height * width)
        v = v.reshape(batch, self.in_channels, height * width)

        # transpose the query tensor for dot product
        q = q.permute(0, 2, 1)

        # main attention formula
        scores = torch.bmm(q, k) * (self.in_channels**-0.5)
        weights = self.softmax(scores)
        weights = weights.permute(0, 2, 1)

        attention = torch.bmm(v, weights)

        # resizing the output from 3D to 4D to match the input
        attention = attention.reshape(batch, self.in_channels, height, width)
        attention = self.project_out(attention)

        # adding the identity to the output
        return x + attention
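Not part of the commit: an illustrative sketch of how the timestep embedding and the time-conditioned ResnetBlock fit together; the shapes and channel counts below are arbitrary examples.

# Illustrative only: sinusoidal timestep embeddings fed into a ResnetBlock.
import torch
from swim.blocks import get_timestep_embedding, ResnetBlock

t = torch.tensor([0, 10, 500])                  # three diffusion timesteps
t_emb = get_timestep_embedding(t, emb_dim=128)  # -> [3, 128]

block = ResnetBlock(in_channels=64, out_channels=128, t_emb_dim=128)
x = torch.randn(3, 64, 32, 32)
h = block(x, t_emb)                             # -> [3, 128, 32, 32]
print(t_emb.shape, h.shape)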
swim/codeblock.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn


class CodeBook(nn.Module):
    def __init__(
        self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: float = 0.25
    ):
        super().__init__()

        self.num_codebook_vectors = num_codebook_vectors
        self.latent_dim = latent_dim
        self.beta = beta

        # creating the codebook; nn.Embedding here is simply a 2D array mainly for storing our embeddings, and it's also learnable
        self.codebook = nn.Embedding(num_codebook_vectors, latent_dim)

        # Initializing the codebook weights with a uniform distribution
        self.codebook.weight.data.uniform_(
            -1 / num_codebook_vectors, 1 / num_codebook_vectors
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Move channels to the last dimension and copy the tensor so it is stored contiguously (in sequence)
        z = z.permute(0, 2, 3, 1).contiguous()

        z_flattened = z.view(
            -1, self.latent_dim
        )  # b*h*w x latent_dim, will look similar to the codebook in fig. 2 of the paper

        # calculating the distance between z and the vectors in the flattened codebook, from eq. 2
        # (a - b)^2 = a^2 + b^2 - 2ab
        distance = (
            torch.sum(
                z_flattened**2, dim=1, keepdim=True
            )  # keepdim=True to keep the same original shape after the sum
            + torch.sum(self.codebook.weight**2, dim=1)
            - 2
            * torch.matmul(
                z_flattened, self.codebook.weight.t()
            )  # 2*dot(z, codebook.T)
        )

        # getting indices of vectors with minimum distance from the codebook
        min_distance_indices = torch.argmin(distance, dim=1)

        # getting the corresponding vectors from the codebook
        z_q = self.codebook(min_distance_indices).view(z.shape)

        """
        This represents equation 4 from the paper (except the reconstruction loss). This loss will then be added
        to the GAN loss to create the final loss function for VQGAN, eq. 6 in the paper.

        Note: in the first paragraph of the A. Changelog section of the paper,
        they report a bug which resulted in beta being equal to 1; see
        https://github.com/CompVis/taming-transformers/issues/57 - just a note :)
        """
        loss = torch.mean(
            (z_q.detach() - z) ** 2
            # detach() to avoid calculating gradients while backpropagating
            + self.beta
            * torch.mean(
                (z_q - z.detach()) ** 2
            )  # commitment loss, detach() to avoid calculating gradients while backpropagating
        )

        # Not sure why we need this, but it's in the original implementation and described as "preserving gradients"
        z_q = z + (z_q - z).detach()

        # reshaping to the original shape
        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_distance_indices, loss
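Not part of the commit: an illustrative call of the CodeBook on a dummy latent, assuming the default 1024-entry codebook with 256-dimensional vectors.

# Illustrative only: quantize an encoder output with the CodeBook.
import torch
from swim.codeblock import CodeBook

codebook = CodeBook(num_codebook_vectors=1024, latent_dim=256)
z = torch.randn(2, 256, 16, 16)        # [batch_size, latent_dim, height, width]
z_q, indices, vq_loss = codebook(z)
print(z_q.shape, indices.shape, float(vq_loss))  # [2, 256, 16, 16], [512], scalar loss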
swim/discriminator.py
ADDED
@@ -0,0 +1,45 @@
import torch.nn as nn


class Discriminator(nn.Module):
    """PatchGAN Discriminator

    Args:
        image_channels (int): Number of channels in the input image.
        num_filters_last (int): Number of filters in the last layer of the discriminator.
        n_layers (int): Number of layers in the discriminator.
    """

    def __init__(self, image_channels: int = 3, num_filters_last=64, n_layers=3):
        super(Discriminator, self).__init__()

        layers = [
            nn.Conv2d(image_channels, num_filters_last, 4, 2, 1),
            nn.LeakyReLU(0.2),
        ]
        num_filters_mult = 1

        for i in range(1, n_layers + 1):
            num_filters_mult_last = num_filters_mult
            num_filters_mult = min(2**i, 8)
            layers += [
                nn.Conv2d(
                    num_filters_last * num_filters_mult_last,
                    num_filters_last * num_filters_mult,
                    4,
                    2 if i < n_layers else 1,
                    1,
                    bias=False,
                ),
                nn.BatchNorm2d(num_filters_last * num_filters_mult),
                nn.LeakyReLU(0.2, True),
            ]

        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
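Not part of the commit: an illustrative check that the PatchGAN discriminator returns a grid of per-patch real/fake logits rather than a single score; the input size is an arbitrary example.

# Illustrative only: the discriminator output is a map of logits.
import torch
from swim.discriminator import Discriminator

disc = Discriminator(image_channels=3, num_filters_last=64, n_layers=3)
img = torch.randn(1, 3, 256, 256)
logits = disc(img)
print(logits.shape)  # torch.Size([1, 1, 30, 30]) for a 256x256 input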
swim/encoder.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn

from .blocks import DownsampleBlock, GroupNorm, AttentionBlock, ResnetBlock


class SwimEncoder(nn.Module):
    """
    The encoder part of the VQGAN.

    Args:
        img_channels (int): Number of channels in the input image.
        image_size (int): Size of the input image, only used in the encoder (height or width).
        latent_channels (int): Number of channels in the latent vector.
        intermediate_channels (list): List of channels in the intermediate layers.
        num_residual_blocks (int): Number of residual blocks between each downsample block.
        dropout (float): Dropout probability for residual blocks.
        attention_resolution (list): Tensor size (height or width) at which to add attention blocks.
    """

    def __init__(
        self,
        img_channels: int = 3,
        image_size: int = 256,
        latent_channels: int = 256,
        intermediate_channels: list = [128, 128, 256, 256, 512],
        num_residual_blocks: int = 2,
        dropout: float = 0.0,
        attention_resolution: list = [16],
    ):
        super().__init__()

        # Inserting the first intermediate channel again at index 0
        intermediate_channels.insert(0, intermediate_channels[0])

        # All the layers are appended to this list
        layers = []

        # Adding the first conv layer to increase the input channels to the first intermediate channels
        layers.append(
            nn.Conv2d(
                img_channels,
                intermediate_channels[0],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )

        # Loop over the intermediate channels except the last one
        for n in range(len(intermediate_channels) - 1):
            in_channels = intermediate_channels[n]
            out_channels = intermediate_channels[n + 1]

            # Adding the residual blocks for each channel
            for _ in range(num_residual_blocks):
                layers.append(ResnetBlock(in_channels, out_channels, dropout=dropout))
                in_channels = out_channels

            # Once we have downsampled the image to a size in attention_resolution, we add attention blocks
            if image_size in attention_resolution:
                layers.append(AttentionBlock(in_channels))

            # only downsample for the first n-2 layers, and decrease the input size by a factor of 2
            if n != len(intermediate_channels) - 2:
                layers.append(DownsampleBlock(intermediate_channels[n + 1]))
                image_size = image_size // 2  # Downsample by a factor of 2

        in_channels = intermediate_channels[-1]
        layers.extend(
            [
                ResnetBlock(
                    in_channels=in_channels, out_channels=in_channels, dropout=dropout
                ),
                AttentionBlock(in_channels=in_channels),
                ResnetBlock(
                    in_channels=in_channels, out_channels=in_channels, dropout=dropout
                ),
                GroupNorm(in_channels=in_channels),
                nn.SiLU(),
                # change the channels up to the latent vector channels
                nn.Conv2d(
                    in_channels, latent_channels, kernel_size=3, stride=1, padding=1
                ),
            ]
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
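Not part of the commit: with the default intermediate channels the encoder downsamples four times (a factor of 16), so a 256x256 image becomes a 16x16, 256-channel latent. The sketch below illustrates that, assuming the package is importable as swim.encoder.

# Illustrative only: the default SwimEncoder reduces the spatial size by a factor of 16.
import torch
from swim.encoder import SwimEncoder

encoder = SwimEncoder(image_size=256)
x = torch.randn(1, 3, 256, 256)
z = encoder(x)
print(z.shape)  # torch.Size([1, 256, 16, 16])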
swim/unet.py
ADDED
@@ -0,0 +1,169 @@
import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention_blocks import SpatialTransformer
from .blocks import (
    DownSample,
    ResnetBlock,
    TimestepEmbedSequential,
    UpSample,
    Normalization,
    get_timestep_embedding,
)


class UNet(nn.Module):
    """
    ## U-Net model
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int,
        channels: int,
        n_res_blocks: int,
        attention_levels: List[int],
        channel_multipliers: List[int],
        n_heads: int,
        tf_layers: int = 1,
        d_cond: int = 768
    ):
        """
        :param in_channels: is the number of channels in the input feature map
        :param out_channels: is the number of channels in the output feature map
        :param channels: is the base channel count for the model
        :param n_res_blocks: number of residual blocks at each level
        :param attention_levels: are the levels at which attention should be performed
        :param channel_multipliers: are the multiplicative factors for number of channels for each level
        :param n_heads: is the number of attention heads in the transformers
        :param tf_layers: is the number of transformer layers in the transformers
        :param d_cond: is the size of the conditional embedding in the transformers
        """
        super().__init__()
        self.channels = channels

        # Number of levels
        levels = len(channel_multipliers)
        # Size of time embeddings
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )

        # Input half of the U-Net
        self.input_blocks = nn.ModuleList()
        # Initial $3 \times 3$ convolution that maps the input to `channels`.
        # The blocks are wrapped in `TimestepEmbedSequential` module because
        # different modules have different forward function signatures;
        # for example, convolution only accepts the feature map and
        # residual blocks accept the feature map and time embedding.
        # `TimestepEmbedSequential` calls them accordingly.
        self.input_blocks.append(
            TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
        )
        # Number of channels at each block in the input half of U-Net
        input_block_channels = [channels]
        # Number of channels at each level
        channels_list = [channels * m for m in channel_multipliers]
        # Prepare levels
        for i in range(levels):
            # Add the residual blocks and attentions
            for _ in range(n_res_blocks):
                # Residual block maps from previous number of channels to the number of
                # channels in the current level
                layers = [
                    ResnetBlock(channels, d_time_emb, out_channels=channels_list[i])
                ]
                channels = channels_list[i]
                # Add transformer
                if i in attention_levels:
                    layers.append(
                        SpatialTransformer(channels, n_heads, tf_layers, d_cond)
                    )
                # Add them to the input half of the U-Net and keep track of the number of channels of
                # its output
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)
            # Down-sample at all levels except last
            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)

        # The middle of the U-Net
        self.middle_block = TimestepEmbedSequential(
            ResnetBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
            ResnetBlock(channels, d_time_emb),
        )

        # Second half of the U-Net
        self.output_blocks = nn.ModuleList([])
        # Prepare levels in reverse order
        for i in reversed(range(levels)):
            # Add the residual blocks and attentions
            for j in range(n_res_blocks + 1):
                # Residual block maps from previous number of channels plus the
                # skip connections from the input half of U-Net to the number of
                # channels in the current level.
                layers = [
                    ResnetBlock(
                        channels + input_block_channels.pop(),
                        d_time_emb,
                        out_channels=channels_list[i],
                    )
                ]
                channels = channels_list[i]
                # Add transformer
                if i in attention_levels:
                    layers.append(
                        SpatialTransformer(channels, n_heads, tf_layers, d_cond)
                    )
                # Up-sample at every level after the last residual block,
                # except the last level.
                # Note that we are iterating in reverse; i.e. `i == 0` is the last.
                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))
                # Add to the output half of the U-Net
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # Final normalization and $3 \times 3$ convolution
        self.out = nn.Sequential(
            Normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, cond: torch.Tensor):
        """
        :param x: is the input feature map of shape `[batch_size, channels, width, height]`
        :param timesteps: are the time steps of shape `[batch_size]`
        :param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
        """
        # To store the input half outputs for skip connections
        x_input_block = []

        # Get time step embeddings
        t_emb = get_timestep_embedding(timesteps, self.channels * 2)
        t_emb = self.time_embed(t_emb)

        # Input half of the U-Net
        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)
        # Middle of the U-Net
        x = self.middle_block(x, t_emb, cond)
        # Output half of the U-Net
        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)

        # Final normalization and $3 \times 3$ convolution
        return self.out(x)
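Not part of the commit: a rough idea of how such a U-Net is typically configured for latent diffusion, and the per-level channel counts the arguments imply. Every value below is an illustrative assumption (e.g. 4 latent channels, a 768-wide conditioning sequence), not something fixed by the code above.

# Illustrative only: hypothetical constructor arguments and the channel counts they imply.
unet_config = dict(
    in_channels=4,                # latent channels coming from an autoencoder
    out_channels=4,               # predicted noise has the same shape as the input latent
    channels=320,                 # base channel count
    n_res_blocks=2,               # residual blocks per level
    attention_levels=[0, 1, 2],   # levels that get a SpatialTransformer
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,
    tf_layers=1,
    d_cond=768,                   # width of the conditioning embeddings
)
channels_list = [unet_config["channels"] * m for m in unet_config["channel_multipliers"]]
print(channels_list)              # [320, 640, 1280, 1280]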
train.py
ADDED
@@ -0,0 +1,8 @@
import torch
from torchinfo import summary
from swim.encoder import SwimEncoder

encoder = SwimEncoder().to("meta")
sample = torch.randn(1, 3, 512, 512).to("meta")

summary(encoder, input_data=(sample,))