qninhdt committed
Commit 9b66f69 • 1 Parent(s): c367a8c
scripts/build_cyclegan_dataset.py CHANGED
@@ -8,7 +8,8 @@ from tqdm import tqdm
 @click.option("--swim_dir", type=str, default="datasets/swim_data")
 @click.option("--output_dir", type=str, default="datasets/swim_data_cyclegan")
 @click.option("--type", type=str, help="fog|rain|snow|night", required=True)
-def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str):
+@click.option("--no_night", is_flag=True)
+def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str, no_night: bool):
     # build the dataset with format
     # swim_data_cyclegan
     # ├── trainA
@@ -42,25 +43,52 @@ def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str):
     with open(os.path.join(swim_dir, "val", "labels.json"), "r") as f:
         val_labels = json.load(f)
 
-    for label in tqdm(train_labels, desc="train"):
-        if label["weather"] == type:
-            os.system(
-                f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
-            )
-        elif label["weather"] == "clear":
-            os.system(
-                f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
-            )
-
-    for label in tqdm(val_labels, desc="val"):
-        if label["weather"] == type:
-            os.system(
-                f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
-            )
-        elif label["weather"] == "clear":
-            os.system(
-                f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testA', label['name'])}"
-            )
+    if type != "night":
+        for label in tqdm(train_labels, desc="train"):
+            if no_night and label["timeofdata"] == "night":
+                continue
+
+            if label["weather"] == type:
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
+                )
+            elif label["weather"] == "clear":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
+                )
+
+        for label in tqdm(val_labels, desc="val"):
+            if no_night and label["timeofdata"] == "night":
+                continue
+
+            if label["weather"] == type:
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
+                )
+            elif label["weather"] == "clear":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testA', label['name'])}"
+                )
+    else:
+        for label in tqdm(train_labels, desc="train"):
+            if label["timeofdata"] == "night":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
+                )
+            elif label["timeofdata"] == "daytime":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainA', label['name'])}"
+                )
+
+        for label in tqdm(val_labels, desc="val"):
+            if label["timeofdata"] == "night":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
+                )
+            elif label["timeofdata"] == "daytime":
+                os.system(
+                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testA', label['name'])}"
+                )
 
 
 if __name__ == "__main__":
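For context (not part of the commit): with the new flag, the script would be run as `python scripts/build_cyclegan_dataset.py --type fog --no_night`, assuming the `__main__` guard invokes the click command. Below is a small, hypothetical sketch of the copy step without shelling out to `cp`; the helper name is illustrative and relies only on the labels.json fields the script already reads (`name`, `weather`, `timeofdata`).

import os
import shutil


def copy_image(swim_dir: str, output_dir: str, split: str, name: str, bucket: str) -> None:
    # Same effect as the script's os.system(f"cp ...") call, but robust to spaces
    # and shell metacharacters in file names. `bucket` is trainA/trainB/testA/testB.
    src = os.path.join(swim_dir, split, "images", name)
    dst = os.path.join(output_dir, bucket, name)
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copy(src, dst)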
swim/__init__.py ADDED
File without changes
swim/attention_blocks.py ADDED
@@ -0,0 +1,315 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SpatialTransformer(nn.Module):
+    """
+    ## Spatial Transformer
+    """
+
+    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
+        """
+        :param channels: is the number of channels in the feature map
+        :param n_heads: is the number of attention heads
+        :param n_layers: is the number of transformer layers
+        :param d_cond: is the size of the conditional embedding
+        """
+        super().__init__()
+        # Initial group normalization
+        self.norm = torch.nn.GroupNorm(
+            num_groups=32, num_channels=channels, eps=1e-6, affine=True
+        )
+        # Initial $1 \times 1$ convolution
+        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+
+        # Transformer layers
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    channels, n_heads, channels // n_heads, d_cond=d_cond
+                )
+                for _ in range(n_layers)
+            ]
+        )
+
+        # Final $1 \times 1$ convolution
+        self.proj_out = nn.Conv2d(
+            channels, channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x: torch.Tensor, cond: torch.Tensor):
+        """
+        :param x: is the feature map of shape `[batch_size, channels, height, width]`
+        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+        """
+        # Get shape `[batch_size, channels, height, width]`
+        b, c, h, w = x.shape
+        # For residual connection
+        x_in = x
+        # Normalize
+        x = self.norm(x)
+        # Initial $1 \times 1$ convolution
+        x = self.proj_in(x)
+        # Transpose and reshape from `[batch_size, channels, height, width]`
+        # to `[batch_size, height * width, channels]`
+        x = x.permute(0, 2, 3, 1).view(b, h * w, c)
+        # Apply the transformer layers
+        for block in self.transformer_blocks:
+            x = block(x, cond)
+        # Reshape and transpose from `[batch_size, height * width, channels]`
+        # to `[batch_size, channels, height, width]`
+        x = x.view(b, h, w, c).permute(0, 3, 1, 2)
+        # Final $1 \times 1$ convolution
+        x = self.proj_out(x)
+        # Add residual
+        return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+    """
+    ### Transformer Layer
+    """
+
+    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
+        """
+        :param d_model: is the input embedding size
+        :param n_heads: is the number of attention heads
+        :param d_head: is the size of an attention head
+        :param d_cond: is the size of the conditional embeddings
+        """
+        super().__init__()
+        # Self-attention layer and pre-norm layer
+        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
+        self.norm1 = nn.LayerNorm(d_model)
+        # Cross attention layer and pre-norm layer
+        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
+        self.norm2 = nn.LayerNorm(d_model)
+        # Feed-forward network and pre-norm layer
+        self.ff = FeedForward(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+
+    def forward(self, x: torch.Tensor, cond: torch.Tensor):
+        """
+        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+        """
+        # Self attention
+        x = self.attn1(self.norm1(x)) + x
+        # Cross-attention with conditioning
+        x = self.attn2(self.norm2(x), cond=cond) + x
+        # Feed-forward network
+        x = self.ff(self.norm3(x)) + x
+        #
+        return x
+
+
+class CrossAttention(nn.Module):
+    """
+    ### Cross Attention Layer
+
+    This falls back to self-attention when conditional embeddings are not specified.
+    """
+
+    use_flash_attention: bool = False
+
+    def __init__(
+        self,
+        d_model: int,
+        d_cond: int,
+        n_heads: int,
+        d_head: int,
+        is_inplace: bool = True,
+    ):
+        """
+        :param d_model: is the input embedding size
+        :param n_heads: is the number of attention heads
+        :param d_head: is the size of an attention head
+        :param d_cond: is the size of the conditional embeddings
+        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+            save memory
+        """
+        super().__init__()
+
+        self.is_inplace = is_inplace
+        self.n_heads = n_heads
+        self.d_head = d_head
+
+        # Attention scaling factor
+        self.scale = d_head**-0.5
+
+        # Query, key and value mappings
+        d_attn = d_head * n_heads
+        self.to_q = nn.Linear(d_model, d_attn, bias=False)
+        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
+        self.to_v = nn.Linear(d_cond, d_attn, bias=False)
+
+        # Final linear layer
+        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
+
+        # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+        # Flash attention is only used if it's installed
+        # and `CrossAttention.use_flash_attention` is set to `True`.
+        # try:
+        #     # You can install flash attention by cloning their Github repo,
+        #     # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+        #     # and then running `python setup.py install`
+        #     from flash_attn.flash_attention import FlashAttention
+
+        #     self.flash = FlashAttention()
+        #     # Set the scale for scaled dot-product attention.
+        #     self.flash.softmax_scale = self.scale
+        # # Set to `None` if it's not installed
+        # except ImportError:
+        #     self.flash = None
+
+    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
+        """
+        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+        """
+
+        # If `cond` is `None` we perform self attention
+        has_cond = cond is not None
+        if not has_cond:
+            cond = x
+
+        # Get query, key and value vectors
+        q = self.to_q(x)
+        k = self.to_k(cond)
+        v = self.to_v(cond)
+
+        # Use flash attention if it's available and the head size is less than or equal to `128`
+        if (
+            CrossAttention.use_flash_attention
+            and self.flash is not None
+            and not has_cond
+            and self.d_head <= 128
+        ):
+            return self.flash_attention(q, k, v)
+        # Otherwise, fall back to normal attention
+        else:
+            return self.normal_attention(q, k, v)
+
+    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        """
+        #### Flash Attention
+
+        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        """
+
+        # Get batch size and number of elements along sequence axis (`width * height`)
+        batch_size, seq_len, _ = q.shape
+
+        # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+        # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+        qkv = torch.stack((q, k, v), dim=2)
+        # Split the heads
+        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+        # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+        # fit this size.
+        if self.d_head <= 32:
+            pad = 32 - self.d_head
+        elif self.d_head <= 64:
+            pad = 64 - self.d_head
+        elif self.d_head <= 128:
+            pad = 128 - self.d_head
+        else:
+            raise ValueError(f"Head size {self.d_head} too large for Flash Attention")
+
+        # Pad the heads
+        if pad:
+            qkv = torch.cat(
+                (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+            )
+
+        # Compute attention
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+        # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+        out, _ = self.flash(qkv)
+        # Truncate the extra head size
+        out = out[:, :, :, : self.d_head]
+        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+        # Map to `[batch_size, height * width, d_model]` with a linear layer
+        return self.to_out(out)
+
+    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        """
+        #### Normal Attention
+
+        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        """
+
+        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+        q = q.view(*q.shape[:2], self.n_heads, -1)
+        k = k.view(*k.shape[:2], self.n_heads, -1)
+        v = v.view(*v.shape[:2], self.n_heads, -1)
+
+        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+        attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+        # Compute softmax
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+        if self.is_inplace:
+            half = attn.shape[0] // 2
+            attn[half:] = attn[half:].softmax(dim=-1)
+            attn[:half] = attn[:half].softmax(dim=-1)
+        else:
+            attn = attn.softmax(dim=-1)
+
+        # Compute attention output
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+        out = torch.einsum("bhij,bjhd->bihd", attn, v)
+        # Reshape to `[batch_size, height * width, n_heads * d_head]`
+        out = out.reshape(*out.shape[:2], -1)
+        # Map to `[batch_size, height * width, d_model]` with a linear layer
+        return self.to_out(out)
+
+
+class FeedForward(nn.Module):
+    """
+    ### Feed-Forward Network
+    """
+
+    def __init__(self, d_model: int, d_mult: int = 4):
+        """
+        :param d_model: is the input embedding size
+        :param d_mult: is the multiplicative factor for the hidden layer size
+        """
+        super().__init__()
+        self.net = nn.Sequential(
+            GeGLU(d_model, d_model * d_mult),
+            nn.Dropout(0.0),
+            nn.Linear(d_model * d_mult, d_model),
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.net(x)
+
+
+class GeGLU(nn.Module):
+    """
+    ### GeGLU Activation
+
+    $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
+    """
+
+    def __init__(self, d_in: int, d_out: int):
+        super().__init__()
+        # Combined linear projections $xW + b$ and $xV + c$
+        self.proj = nn.Linear(d_in, d_out * 2)
+
+    def forward(self, x: torch.Tensor):
+        # Get $xW + b$ and $xV + c$
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
+        return x * F.gelu(gate)
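A quick shape check for the new module (illustrative only, not part of the commit). The channel count must be divisible by 32 because of the GroupNorm, and `d_cond=768` is just an assumed CLIP-sized conditioning width:

import torch

from swim.attention_blocks import SpatialTransformer

feat = torch.randn(2, 64, 16, 16)   # [batch, channels, height, width]
cond = torch.randn(2, 77, 768)      # [batch, n_cond, d_cond]

st = SpatialTransformer(channels=64, n_heads=4, n_layers=1, d_cond=768)
out = st(feat, cond)
print(out.shape)  # torch.Size([2, 64, 16, 16]) -- same shape, thanks to the residual connection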
swim/autoencoder.py ADDED
File without changes
swim/blocks.py ADDED
@@ -0,0 +1,185 @@
+from abc import abstractmethod
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor, emb_dim: int, max_period: int = 10000
+) -> torch.Tensor:
+    half_dim = emb_dim // 2
+
+    emb = math.log(max_period) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+
+    if emb_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+
+    return emb
+
+
+class GroupNorm(nn.Module):
+    def __init__(self, in_channels: int) -> None:
+        super().__init__()
+
+        self.group_norm = nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-06, affine=True
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.group_norm(x)
+
+
+class UpsampleBlock(nn.Module):
+    def __init__(self, channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
+
+    def forward(self, x: torch.Tensor):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        return self.conv(x)
+
+
+class DownsampleBlock(nn.Module):
+    def __init__(self, channels: int):
+        super().__init__()
+        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
+
+    def forward(self, x: torch.Tensor):
+        return self.op(x)
+
+
+class TimestepBlock(nn.Module):
+    @abstractmethod
+    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, t_emb)
+            else:
+                x = layer(x)
+        return x
+
+
+class ResnetBlock(nn.Module):
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int = None,
+        t_emb_dim: int = None,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        if out_channels is None:
+            out_channels = in_channels
+
+        self.input_layers = nn.Sequential(
+            GroupNorm(in_channels),
+            nn.SiLU(),
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+        )
+
+        if t_emb_dim is not None:
+            self.t_emb_layers = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(t_emb_dim, out_channels),
+            )
+        else:
+            self.t_emb_layers = None
+
+        self.output_layers = nn.Sequential(
+            GroupNorm(out_channels),
+            nn.SiLU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
+        )
+
+        if in_channels != out_channels:
+            self.skip = nn.Conv2d(in_channels, out_channels, 1)
+        else:
+            self.skip = nn.Identity()
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor = None) -> torch.Tensor:
+        assert t is not None or self.t_emb_layers is None
+
+        h = self.input_layers(x)
+
+        if self.t_emb_layers is not None:
+            t_emb = self.t_emb_layers(t)
+            h = h + t_emb[:, :, None, None]
+
+        h = self.output_layers(h)
+
+        h = h + self.skip(x)
+
+        return h
+
+
+class AttentionBlock(nn.Module):
+    """Attention mechanism similar to transformers but for CNNs, paper https://arxiv.org/abs/1805.08318
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+    """
+
+    def __init__(self, in_channels: int) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+
+        # normalization layer
+        self.norm = GroupNorm(in_channels)
+
+        # query, key and value layers
+        self.q = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
+        self.k = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
+        self.v = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
+
+        self.project_out = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
+
+        self.softmax = nn.Softmax(dim=2)
+
+    def forward(self, x):
+
+        batch, _, height, width = x.size()
+
+        x = self.norm(x)
+
+        # query, key and value layers
+        q = self.q(x)
+        k = self.k(x)
+        v = self.v(x)
+
+        # resizing the output from 4D to 3D to generate attention map
+        q = q.reshape(batch, self.in_channels, height * width)
+        k = k.reshape(batch, self.in_channels, height * width)
+        v = v.reshape(batch, self.in_channels, height * width)
+
+        # transpose the query tensor for dot product
+        q = q.permute(0, 2, 1)
+
+        # main attention formula
+        scores = torch.bmm(q, k) * (self.in_channels**-0.5)
+        weights = self.softmax(scores)
+        weights = weights.permute(0, 2, 1)
+
+        attention = torch.bmm(v, weights)
+
+        # resizing the output from 3D to 4D to match the input
+        attention = attention.reshape(batch, self.in_channels, height, width)
+        attention = self.project_out(attention)
+
+        # adding the identity to the output
+        return x + attention
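These blocks can be exercised together on dummy tensors (illustrative only, not part of the commit); the shapes are arbitrary, but channel counts must be divisible by 32 for the GroupNorm wrapper:

import torch

from swim.blocks import AttentionBlock, ResnetBlock, get_timestep_embedding

x = torch.randn(2, 64, 32, 32)                  # dummy feature map
t = torch.randint(0, 1000, (2,))                # dummy diffusion timesteps
t_emb = get_timestep_embedding(t, emb_dim=128)  # [2, 128] sinusoidal embedding

res = ResnetBlock(in_channels=64, out_channels=128, t_emb_dim=128)
attn = AttentionBlock(in_channels=128)

h = res(x, t_emb)  # [2, 128, 32, 32]; the projected time embedding is added channel-wise
h = attn(h)        # same shape, with single-head self-attention applied
print(h.shape)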
swim/codeblock.py ADDED
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+
+class CodeBook(nn.Module):
+    def __init__(
+        self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: float = 0.25
+    ):
+        super().__init__()
+
+        self.num_codebook_vectors = num_codebook_vectors
+        self.latent_dim = latent_dim
+        self.beta = beta
+
+        # creating the codebook; nn.Embedding here is simply a 2D array mainly for storing our embeddings, and it's also learnable
+        self.codebook = nn.Embedding(num_codebook_vectors, latent_dim)
+
+        # Initializing the codebook weights with a uniform distribution
+        self.codebook.weight.data.uniform_(
+            -1 / num_codebook_vectors, 1 / num_codebook_vectors
+        )
+
+    def forward(self, z: torch.Tensor) -> torch.Tensor:
+        # Channel to last dimension and copying the tensor to store it in a contiguous (in a sequence) way
+        z = z.permute(0, 2, 3, 1).contiguous()
+
+        z_flattened = z.view(
+            -1, self.latent_dim
+        )  # b*h*w x latent_dim, will look similar to the codebook in fig. 2 of the paper
+
+        # calculating the distance between z and the vectors in the flattened codebook, from eq. 2
+        # (a - b)^2 = a^2 + b^2 - 2ab
+        distance = (
+            torch.sum(
+                z_flattened**2, dim=1, keepdim=True
+            )  # keepdim=True to keep the same original shape after the sum
+            + torch.sum(self.codebook.weight**2, dim=1)
+            - 2
+            * torch.matmul(
+                z_flattened, self.codebook.weight.t()
+            )  # 2*dot(z, codebook.T)
+        )
+
+        # getting indices of vectors with minimum distance from the codebook
+        min_distance_indices = torch.argmin(distance, dim=1)
+
+        # getting the corresponding vector from the codebook
+        z_q = self.codebook(min_distance_indices).view(z.shape)
+
+        """
+        This represents equation 4 from the paper (except the reconstruction loss). This loss will then be added
+        to the GAN loss to create the final loss function for VQGAN, eq. 6 in the paper.
+
+        Note: in the first paragraph of the A. Changelog section of the paper,
+        they found a bug which resulted in beta equal to 1, see https://github.com/CompVis/taming-transformers/issues/57
+        just a note :)
+        """
+        loss = torch.mean(
+            (z_q.detach() - z) ** 2
+            # detach() to avoid calculating gradients while backpropagating
+            + self.beta
+            * torch.mean(
+                (z_q - z.detach()) ** 2
+            )  # commitment loss, detach() to avoid calculating gradients while backpropagating
+        )
+
+        # straight-through trick from the original implementation, mentioned there as "preserving gradients"
+        z_q = z + (z_q - z).detach()
+
+        # reshaping to the original shape
+        z_q = z_q.permute(0, 3, 1, 2)
+
+        return z_q, min_distance_indices, loss
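A dummy pass through the codebook (illustrative only, not part of the commit), showing the three values it returns:

import torch

from swim.codeblock import CodeBook

codebook = CodeBook(num_codebook_vectors=1024, latent_dim=256)
z = torch.randn(2, 256, 16, 16)  # pretend encoder output: [batch, latent_dim, h, w]

z_q, indices, vq_loss = codebook(z)
print(z_q.shape)       # torch.Size([2, 256, 16, 16]) -- quantized latent, same shape as z
print(indices.shape)   # torch.Size([512]) -- one codebook index per spatial position (2*16*16)
print(vq_loss.item())  # scalar codebook + commitment loss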
swim/discriminator.py ADDED
@@ -0,0 +1,45 @@
+import torch.nn as nn
+
+
+class Discriminator(nn.Module):
+    """PatchGAN Discriminator
+
+    Args:
+        image_channels (int): Number of channels in the input image.
+        num_filters_last (int): Number of filters in the last layer of the discriminator.
+        n_layers (int): Number of layers in the discriminator.
+    """
+
+    def __init__(self, image_channels: int = 3, num_filters_last=64, n_layers=3):
+        super(Discriminator, self).__init__()
+
+        layers = [
+            nn.Conv2d(image_channels, num_filters_last, 4, 2, 1),
+            nn.LeakyReLU(0.2),
+        ]
+        num_filters_mult = 1
+
+        for i in range(1, n_layers + 1):
+            num_filters_mult_last = num_filters_mult
+            num_filters_mult = min(2**i, 8)
+            layers += [
+                nn.Conv2d(
+                    num_filters_last * num_filters_mult_last,
+                    num_filters_last * num_filters_mult,
+                    4,
+                    2 if i < n_layers else 1,
+                    1,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(num_filters_last * num_filters_mult),
+                nn.LeakyReLU(0.2, True),
+            ]
+
+        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
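Illustrative shape check (not part of the commit): with the defaults, a 256x256 image maps to a 30x30 grid of per-patch realness logits, the usual PatchGAN output:

import torch

from swim.discriminator import Discriminator

disc = Discriminator(image_channels=3, num_filters_last=64, n_layers=3)
img = torch.randn(1, 3, 256, 256)

logits = disc(img)
print(logits.shape)  # torch.Size([1, 1, 30, 30])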
swim/encoder.py ADDED
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+
+from .blocks import DownsampleBlock, GroupNorm, AttentionBlock, ResnetBlock
+
+
+class SwimEncoder(nn.Module):
+    """
+    The encoder part of the VQGAN.
+
+    Args:
+        img_channels (int): Number of channels in the input image.
+        image_size (int): Size of the input image (height or width), only used in the encoder.
+        latent_channels (int): Number of channels in the latent vector.
+        intermediate_channels (list): List of channels in the intermediate layers.
+        num_residual_blocks (int): Number of residual blocks between each downsample block.
+        dropout (float): Dropout probability for residual blocks.
+        attention_resolution (list): tensor sizes (height or width) at which to add attention blocks
+    """
+
+    def __init__(
+        self,
+        img_channels: int = 3,
+        image_size: int = 256,
+        latent_channels: int = 256,
+        intermediate_channels: list = [128, 128, 256, 256, 512],
+        num_residual_blocks: int = 2,
+        dropout: float = 0.0,
+        attention_resolution: list = [16],
+    ):
+        super().__init__()
+
+        # Inserting the first intermediate channel at index 0
+        intermediate_channels.insert(0, intermediate_channels[0])
+
+        # All the layers are appended to this list
+        layers = []
+
+        # Adding the first conv layer to increase the input channels to the first intermediate channels
+        layers.append(
+            nn.Conv2d(
+                img_channels,
+                intermediate_channels[0],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+
+        # Loop over the intermediate channels except the last one
+        for n in range(len(intermediate_channels) - 1):
+            in_channels = intermediate_channels[n]
+            out_channels = intermediate_channels[n + 1]
+
+            # Adding the residual blocks for each channel
+            for _ in range(num_residual_blocks):
+                layers.append(ResnetBlock(in_channels, out_channels, dropout=dropout))
+                in_channels = out_channels
+
+            # Once we have downsampled the image to a size in attention_resolution, we add attention blocks
+            if image_size in attention_resolution:
+                layers.append(AttentionBlock(in_channels))
+
+            # only downsample for the first n-2 layers, and decrease the input size by a factor of 2
+            if n != len(intermediate_channels) - 2:
+                layers.append(DownsampleBlock(intermediate_channels[n + 1]))
+                image_size = image_size // 2  # Downsample by a factor of 2
+
+        in_channels = intermediate_channels[-1]
+        layers.extend(
+            [
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=in_channels, dropout=dropout
+                ),
+                AttentionBlock(in_channels=in_channels),
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=in_channels, dropout=dropout
+                ),
+                GroupNorm(in_channels=in_channels),
+                nn.SiLU(),
+                # increase the channels up to the latent vector channels
+                nn.Conv2d(
+                    in_channels, latent_channels, kernel_size=3, stride=1, padding=1
+                ),
+            ]
+        )
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
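Illustrative shape check (not part of the commit): with the default channel list the encoder applies four 2x downsamples, so a 256x256 image becomes a 16x16 latent with `latent_channels` channels. Note that `intermediate_channels` is a mutable default argument that `insert(0, ...)` modifies in place, so constructing the encoder repeatedly with the default list keeps growing it.

import torch

from swim.encoder import SwimEncoder

encoder = SwimEncoder(img_channels=3, image_size=256, latent_channels=256)
img = torch.randn(1, 3, 256, 256)

z = encoder(img)
print(z.shape)  # torch.Size([1, 256, 16, 16])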
swim/unet.py ADDED
@@ -0,0 +1,169 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention_blocks import SpatialTransformer
+from .blocks import (
+    DownSample,
+    ResnetBlock,
+    TimestepEmbedSequential,
+    UpSample,
+    Normalization,
+    get_timestep_embedding,
+)
+
+
+class UNet(nn.Module):
+    """
+    ## U-Net model
+    """
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: int,
+        channels: int,
+        n_res_blocks: int,
+        attention_levels: List[int],
+        channel_multipliers: List[int],
+        n_heads: int,
+        tf_layers: int = 1,
+        d_cond: int = 768
+    ):
+        """
+        :param in_channels: is the number of channels in the input feature map
+        :param out_channels: is the number of channels in the output feature map
+        :param channels: is the base channel count for the model
+        :param n_res_blocks: number of residual blocks at each level
+        :param attention_levels: are the levels at which attention should be performed
+        :param channel_multipliers: are the multiplicative factors for the number of channels at each level
+        :param n_heads: is the number of attention heads in the transformers
+        :param tf_layers: is the number of transformer layers in the transformers
+        :param d_cond: is the size of the conditional embedding in the transformers
+        """
+        super().__init__()
+        self.channels = channels
+
+        # Number of levels
+        levels = len(channel_multipliers)
+        # Size of the time embeddings
+        d_time_emb = channels * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(channels, d_time_emb),
+            nn.SiLU(),
+            nn.Linear(d_time_emb, d_time_emb),
+        )
+
+        # Input half of the U-Net
+        self.input_blocks = nn.ModuleList()
+        # Initial $3 \times 3$ convolution that maps the input to `channels`.
+        # The blocks are wrapped in `TimestepEmbedSequential` module because
+        # different modules have different forward function signatures;
+        # for example, convolution only accepts the feature map and
+        # residual blocks accept the feature map and time embedding.
+        # `TimestepEmbedSequential` calls them accordingly.
+        self.input_blocks.append(
+            TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
+        )
+        # Number of channels at each block in the input half of the U-Net
+        input_block_channels = [channels]
+        # Number of channels at each level
+        channels_list = [channels * m for m in channel_multipliers]
+        # Prepare levels
+        for i in range(levels):
+            # Add the residual blocks and attentions
+            for _ in range(n_res_blocks):
+                # Residual block maps from the previous number of channels to the number of
+                # channels in the current level
+                layers = [
+                    ResnetBlock(channels, d_time_emb, out_channels=channels_list[i])
+                ]
+                channels = channels_list[i]
+                # Add transformer
+                if i in attention_levels:
+                    layers.append(
+                        SpatialTransformer(channels, n_heads, tf_layers, d_cond)
+                    )
+                # Add them to the input half of the U-Net and keep track of the number of channels of
+                # its output
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                input_block_channels.append(channels)
+            # Down sample at all levels except the last
+            if i != levels - 1:
+                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
+                input_block_channels.append(channels)
+
+        # The middle of the U-Net
+        self.middle_block = TimestepEmbedSequential(
+            ResnetBlock(channels, d_time_emb),
+            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
+            ResnetBlock(channels, d_time_emb),
+        )
+
+        # Second half of the U-Net
+        self.output_blocks = nn.ModuleList([])
+        # Prepare levels in reverse order
+        for i in reversed(range(levels)):
+            # Add the residual blocks and attentions
+            for j in range(n_res_blocks + 1):
+                # Residual block maps from the previous number of channels plus the
+                # skip connections from the input half of the U-Net to the number of
+                # channels in the current level.
+                layers = [
+                    ResnetBlock(
+                        channels + input_block_channels.pop(),
+                        d_time_emb,
+                        out_channels=channels_list[i],
+                    )
+                ]
+                channels = channels_list[i]
+                # Add transformer
+                if i in attention_levels:
+                    layers.append(
+                        SpatialTransformer(channels, n_heads, tf_layers, d_cond)
+                    )
+                # Up-sample at every level after the last residual block
+                # except the last one.
+                # Note that we are iterating in reverse; i.e. `i == 0` is the last.
+                if i != 0 and j == n_res_blocks:
+                    layers.append(UpSample(channels))
+                # Add to the output half of the U-Net
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+        # Final normalization and $3 \times 3$ convolution
+        self.out = nn.Sequential(
+            Normalization(channels),
+            nn.SiLU(),
+            nn.Conv2d(channels, out_channels, 3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, cond: torch.Tensor):
+        """
+        :param x: is the input feature map of shape `[batch_size, channels, width, height]`
+        :param timesteps: are the time steps of shape `[batch_size]`
+        :param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
+        """
+        # To store the input half outputs for skip connections
+        x_input_block = []
+
+        # Get time step embeddings
+        t_emb = get_timestep_embedding(timesteps, self.channels * 2)
+        t_emb = self.time_embed(t_emb)
+
+        # Input half of the U-Net
+        for module in self.input_blocks:
+            x = module(x, t_emb, cond)
+            x_input_block.append(x)
+        # Middle of the U-Net
+        x = self.middle_block(x, t_emb, cond)
+        # Output half of the U-Net
+        for module in self.output_blocks:
+            x = torch.cat([x, x_input_block.pop()], dim=1)
+            x = module(x, t_emb, cond)
+
+        # Final normalization and $3 \times 3$ convolution
+        return self.out(x)
train.py ADDED
@@ -0,0 +1,8 @@
+import torch
+from torchinfo import summary
+from swim.encoder import SwimEncoder
+
+encoder = SwimEncoder().to("meta")
+sample = torch.randn(1, 3, 512, 512).to("meta")
+
+summary(encoder, input_data=(sample,))