qninhdt committed on
Commit 798fdd3
1 Parent(s): be91bac
scripts/build_cyclegan_dataset.py CHANGED
@@ -71,6 +71,9 @@ def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str, no_night:
                )
    else:
        for label in tqdm(train_labels, desc="train"):
+           if label["weather"] != "clear":
+               continue
+
            if label["timeofday"] == "night":
                os.system(
                    f"cp {os.path.join(swim_dir, 'train', 'images', label['name'])} {os.path.join(output_dir, 'trainB', label['name'])}"
@@ -81,6 +84,9 @@ def build_cyclegan_dataset(swim_dir: str, output_dir: str, type: str, no_night:
                )

        for label in tqdm(val_labels, desc="val"):
+           if label["weather"] != "clear":
+               continue
+
            if label["timeofday"] == "night":
                os.system(
                    f"cp {os.path.join(swim_dir, 'val', 'images', label['name'])} {os.path.join(output_dir, 'testB', label['name'])}"
swim/autoencoder.py CHANGED
@@ -0,0 +1,247 @@
+ from typing import List
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from .blocks import (
+     ResnetBlock,
+     AttentionBlock,
+     GroupNorm,
+     UpSampleBlock,
+     DownSampleBlock,
+ )
+
+
+ class Autoencoder(nn.Module):
+
+     def __init__(
+         self,
+         channels: int,
+         channel_multipliers: List[int],
+         n_resnet_blocks: int,
+         in_channels: int,
+         z_channels: int,
+         emb_channels: int,
+     ):
+         super().__init__()
+         self.encoder = Encoder(
+             channels=channels,
+             channel_multipliers=channel_multipliers,
+             n_resnet_blocks=n_resnet_blocks,
+             in_channels=in_channels,
+             z_channels=z_channels,
+         )
+         self.decoder = Decoder(
+             channels=channels,
+             channel_multipliers=channel_multipliers,
+             n_resnet_blocks=n_resnet_blocks,
+             out_channels=in_channels,
+             z_channels=z_channels,
+         )
+         # Convolution to map from embedding space to
+         # quantized embedding space moments (mean and log variance)
+         self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
+         # Convolution to map from quantized embedding space back to
+         # embedding space
+         self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
+
+     def encode(self, img: torch.Tensor) -> "GaussianDistribution":
+         # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_width]`
+         z = self.encoder(img)
+         # Get the moments in the quantized embedding space
+         moments = self.quant_conv(z)
+         # Return the distribution
+         return GaussianDistribution(moments)
+
+     def decode(self, z: torch.Tensor):
+         # Map to embedding space from the quantized representation
+         z = self.post_quant_conv(z)
+         # Decode the image of shape `[batch_size, channels, height, width]`
+         return self.decoder(z)
+
+     def forward(self, x: torch.Tensor, sample_posterior: bool = False):
+         posterior = self.encode(x)
+         if sample_posterior:
+             z = posterior.sample()
+         else:
+             z = posterior.mode()
+         decoded_x = self.decode(z)
+         return decoded_x, posterior
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         *,
+         channels: int,
+         channel_multipliers: List[int],
+         n_resnet_blocks: int,
+         in_channels: int,
+         z_channels: int
+     ):
+         super().__init__()
+
+         # Number of blocks of different resolutions.
+         # The resolution is halved at the end of each top-level block
+         n_resolutions = len(channel_multipliers)
+
+         # Initial $3 \times 3$ convolution layer that maps the image to `channels`
+         self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
+
+         # Number of channels in each top-level block
+         channels_list = [m * channels for m in [1] + channel_multipliers]
+
+         # List of top-level blocks
+         self.down = nn.ModuleList()
+         # Create top-level blocks
+         for i in range(n_resolutions):
+             # Each top-level block consists of multiple ResNet blocks and down-sampling
+             resnet_blocks = nn.ModuleList()
+             # Add ResNet blocks
+             for _ in range(n_resnet_blocks):
+                 resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
+                 channels = channels_list[i + 1]
+             # Top-level block
+             down = nn.Module()
+             down.block = resnet_blocks
+             # Down-sampling at the end of each top-level block except the last
+             if i != n_resolutions - 1:
+                 down.downsample = DownSampleBlock(channels)
+             else:
+                 down.downsample = nn.Identity()
+             #
+             self.down.append(down)
+
+         # Final ResNet blocks with attention
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(channels, channels)
+         self.mid.attn_1 = AttentionBlock(channels)
+         self.mid.block_2 = ResnetBlock(channels, channels)
+
+         # Map to embedding space with a $3 \times 3$ convolution
+         self.norm_out = GroupNorm(channels)
+         self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
+
+     def forward(self, img: torch.Tensor):
+         # Map to `channels` with the initial convolution
+         x = self.conv_in(img)
+
+         # Top-level blocks
+         for down in self.down:
+             # ResNet blocks
+             for block in down.block:
+                 x = block(x)
+             # Down-sampling
+             x = down.downsample(x)
+
+         # Final ResNet blocks with attention
+         x = self.mid.block_1(x)
+         x = self.mid.attn_1(x)
+         x = self.mid.block_2(x)
+
+         # Normalize and map to embedding space
+         x = self.norm_out(x)
+         x = F.silu(x)
+         x = self.conv_out(x)
+
+         return x
+
+
+ class Decoder(nn.Module):
+
+     def __init__(
+         self,
+         *,
+         channels: int,
+         channel_multipliers: List[int],
+         n_resnet_blocks: int,
+         out_channels: int,
+         z_channels: int
+     ):
+         super().__init__()
+
+         # Number of blocks of different resolutions.
+         # The resolution is doubled at the end of each top-level block
+         num_resolutions = len(channel_multipliers)
+
+         # Number of channels in each top-level block, in the reverse order
+         channels_list = [m * channels for m in channel_multipliers]
+
+         # Number of channels in the top-level block
+         channels = channels_list[-1]
+
+         # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
+         self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
+
+         # ResNet blocks with attention
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(channels, channels)
+         self.mid.attn_1 = AttentionBlock(channels)
+         self.mid.block_2 = ResnetBlock(channels, channels)
+
+         # List of top-level blocks
+         self.up = nn.ModuleList()
+         # Create top-level blocks
+         for i in reversed(range(num_resolutions)):
+             # Each top-level block consists of multiple ResNet blocks and up-sampling
+             resnet_blocks = nn.ModuleList()
+             # Add ResNet blocks
+             for _ in range(n_resnet_blocks + 1):
+                 resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
+                 channels = channels_list[i]
+             # Top-level block
+             up = nn.Module()
+             up.block = resnet_blocks
+             # Up-sampling at the end of each top-level block except the first
+             if i != 0:
+                 up.upsample = UpSampleBlock(channels)
+             else:
+                 up.upsample = nn.Identity()
+             # Prepend to be consistent with the checkpoint
+             self.up.insert(0, up)
+
+         # Map to image space with a $3 \times 3$ convolution
+         self.norm_out = GroupNorm(channels)
+         self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
+
+     def forward(self, z: torch.Tensor):
+         # Map to `channels` with the initial convolution
+         h = self.conv_in(z)
+
+         # ResNet blocks with attention
+         h = self.mid.block_1(h)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h)
+
+         # Top-level blocks
+         for up in reversed(self.up):
+             # ResNet blocks
+             for block in up.block:
+                 h = block(h)
+             # Up-sampling
+             h = up.upsample(h)
+
+         # Normalize and map to image space
+         h = self.norm_out(h)
+         h = F.silu(h)
+         img = self.conv_out(h)
+
+         return img
+
+
+ class GaussianDistribution:
+     def __init__(self, parameters: torch.Tensor):
+         # Split mean and log of variance
+         self.mean, log_var = torch.chunk(parameters, 2, dim=1)
+         # Clamp the log of variances
+         self.log_var = torch.clamp(log_var, -30.0, 20.0)
+         # Calculate standard deviation
+         self.std = torch.exp(0.5 * self.log_var)
+
+     def sample(self):
+         # Sample from the distribution
+         return self.mean + self.std * torch.randn_like(self.std)
+
+     def mode(self):
+         return self.mean
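For orientation, here is a minimal usage sketch of the new Autoencoder (not part of the commit). The constructor values mirror the commented-out configuration later in train.py and are assumptions rather than a tested setup; with four channel multipliers the encoder downsamples three times, so a 256x256 input should map to a 32x32 latent.

import torch
from swim.autoencoder import Autoencoder

# Assumed hyperparameters, copied from the commented-out block in train.py.
vae = Autoencoder(
    channels=128,
    channel_multipliers=[1, 2, 4, 4],
    n_resnet_blocks=2,
    in_channels=3,
    z_channels=4,
    emb_channels=4,
)

x = torch.randn(1, 3, 256, 256)
# encode() returns a GaussianDistribution over the latent space;
# forward() reconstructs the input and also returns that posterior.
recon, posterior = vae(x, sample_posterior=True)
z = posterior.sample()   # expected shape: [1, 4, 32, 32] after three stride-2 downsamples
recon2 = vae.decode(z)   # expected shape: [1, 3, 256, 256]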
swim/blocks.py CHANGED
@@ -36,23 +36,24 @@ class GroupNorm(nn.Module):
        return self.group_norm(x)


- class UpsampleBlock(nn.Module):
+ class UpSampleBlock(nn.Module):
      def __init__(self, channels: int):
          super().__init__()
          self.conv = nn.Conv2d(channels, channels, 3, padding=1)

      def forward(self, x: torch.Tensor):
-         x = F.interpolate(x, scale_factor=2, mode="nearest")
+         x = F.interpolate(x, scale_factor=2.0, mode="nearest")
          return self.conv(x)


- class DownsampleBlock(nn.Module):
+ class DownSampleBlock(nn.Module):
      def __init__(self, channels: int):
          super().__init__()
-         self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
+         self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

      def forward(self, x: torch.Tensor):
-         return self.op(x)
+         x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
+         return self.conv(x)


  class TimestepBlock(nn.Module):
@@ -128,12 +129,6 @@ class ResnetBlock(nn.Module):


  class AttentionBlock(nn.Module):
-     """Attention mechanism similar to transformers but for CNNs, paper https://arxiv.org/abs/1805.08318
-
-     Args:
-         in_channels (int): Number of channels in the input tensor.
-     """
-
      def __init__(self, in_channels: int) -> None:
          super().__init__()

@@ -183,3 +178,50 @@

          # adding the identity to the output
          return x + attention
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(self, channels: int):
+         super().__init__()
+         # Group normalization
+         self.norm = GroupNorm(channels)
+         # Query, key and value mappings
+         self.q = nn.Conv2d(channels, channels, 1)
+         self.k = nn.Conv2d(channels, channels, 1)
+         self.v = nn.Conv2d(channels, channels, 1)
+
+         self.proj_out = nn.Conv2d(channels, channels, 1)
+
+         # Attention scaling factor
+         self.scale = channels**-0.5
+
+     def forward(self, x: torch.Tensor):
+         # Normalize `x`
+         x_norm = self.norm(x)
+         # Get query, key and value embeddings
+         q = self.q(x_norm)
+         k = self.k(x_norm)
+         v = self.v(x_norm)
+
+         # Reshape the query, key and value embeddings from
+         # `[batch_size, channels, height, width]` to
+         # `[batch_size, channels, height * width]`
+         b, c, h, w = q.shape
+         q = q.view(b, c, h * w)
+         k = k.view(b, c, h * w)
+         v = v.view(b, c, h * w)
+
+         # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
+         attn = torch.einsum("bci,bcj->bij", q, k) * self.scale
+         attn = F.softmax(attn, dim=2)
+
+         # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
+         out = torch.einsum("bij,bcj->bci", attn, v)
+
+         # Reshape back to `[batch_size, channels, height, width]`
+         out = out.view(b, c, h, w)
+         # Final $1 \times 1$ convolution layer
+         out = self.proj_out(out)
+
+         # Add residual connection
+         return x + out
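A quick shape check (not in the commit) of the reworked sampling blocks: DownSampleBlock now pads only the right and bottom edges before a stride-2, padding-0 convolution, which halves height and width exactly, while UpSampleBlock doubles them with nearest-neighbour interpolation before its 3x3 convolution.

import torch
from swim.blocks import DownSampleBlock, UpSampleBlock

x = torch.randn(1, 64, 32, 32)

down = DownSampleBlock(64)
y = down(x)
print(y.shape)  # torch.Size([1, 64, 16, 16]): pad to 33x33, then (33 - 3) // 2 + 1 = 16

up = UpSampleBlock(64)
z = up(y)
print(z.shape)  # torch.Size([1, 64, 32, 32]): nearest upsample to 32x32, 3x3 conv keeps the size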
swim/codeblock.py CHANGED
@@ -2,7 +2,7 @@ import torch
  import torch.nn as nn


- class CodeBook(nn.Module):
+ class SwimCodeBook(nn.Module):
      def __init__(
          self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: int = 0.25
      ):
swim/encoder.py DELETED
@@ -1,90 +0,0 @@
- import torch
- import torch.nn as nn
-
- from .blocks import DownsampleBlock, GroupNorm, AttentionBlock, ResnetBlock
-
-
- class SwimEncoder(nn.Module):
-     """
-     The encoder part of the VQGAN.
-
-     Args:
-         img_channels (int): Number of channels in the input image.
-         image_size (int): Size of the input image, only used in the encoder (height or width).
-         latent_channels (int): Number of channels in the latent vector.
-         intermediate_channels (list): List of channels in the intermediate layers.
-         num_residual_blocks (int): Number of residual blocks between each downsample block.
-         dropout (float): Dropout probability for residual blocks.
-         attention_resolution (list): Tensor size (height or width) at which to add attention blocks.
-     """
-
-     def __init__(
-         self,
-         img_channels: int = 3,
-         image_size: int = 256,
-         latent_channels: int = 256,
-         intermediate_channels: list = [128, 128, 256, 256, 512],
-         num_residual_blocks: int = 2,
-         dropout: float = 0.0,
-         attention_resolution: list = [16],
-     ):
-         super().__init__()
-
-         # Inserting first intermediate channel to index 0
-         intermediate_channels.insert(0, intermediate_channels[0])
-
-         # Appends all the layers to this list
-         layers = []
-
-         # Adding the first conv layer to increase input channels to the first intermediate channels
-         layers.append(
-             nn.Conv2d(
-                 img_channels,
-                 intermediate_channels[0],
-                 kernel_size=3,
-                 stride=1,
-                 padding=1,
-             )
-         )
-
-         # Loop over the intermediate channels except the last one
-         for n in range(len(intermediate_channels) - 1):
-             in_channels = intermediate_channels[n]
-             out_channels = intermediate_channels[n + 1]
-
-             # Adding the residual blocks for each channel
-             for _ in range(num_residual_blocks):
-                 layers.append(ResnetBlock(in_channels, out_channels, dropout=dropout))
-                 in_channels = out_channels
-
-             # Once we have downsampled the image to the size in attention_resolution, we add attention blocks
-             if image_size in attention_resolution:
-                 layers.append(AttentionBlock(in_channels))
-
-             # Only downsample for the first n-2 layers, and decrease the input size by a factor of 2
-             if n != len(intermediate_channels) - 2:
-                 layers.append(DownsampleBlock(intermediate_channels[n + 1]))
-                 image_size = image_size // 2  # Downsample by a factor of 2
-
-         in_channels = intermediate_channels[-1]
-         layers.extend(
-             [
-                 ResnetBlock(
-                     in_channels=in_channels, out_channels=in_channels, dropout=dropout
-                 ),
-                 AttentionBlock(in_channels=in_channels),
-                 ResnetBlock(
-                     in_channels=in_channels, out_channels=in_channels, dropout=dropout
-                 ),
-                 GroupNorm(in_channels=in_channels),
-                 nn.SiLU(),
-                 # Increase the channels up to the latent vector channels
-                 nn.Conv2d(
-                     in_channels, latent_channels, kernel_size=3, stride=1, padding=1
-                 ),
-             ]
-         )
-         self.model = nn.Sequential(*layers)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.model(x)
train.py CHANGED
@@ -1,8 +1,72 @@
  import torch
  from torchinfo import summary
- from swim.encoder import SwimEncoder

- encoder = SwimEncoder().to("meta")
+ from swim.autoencoder import Autoencoder
+ from diffusers import AutoencoderKL, UNet2DModel
+
+ # vae = Autoencoder(
+ #     z_channels=4,
+ #     in_channels=3,
+ #     channels=128,
+ #     channel_multipliers=[1, 2, 4, 4],
+ #     n_resnet_blocks=2,
+ #     emb_channels=4,
+ # ).to("meta")
+ # lol_vae = AutoencoderKL.from_pretrained(
+ #     "stabilityai/stable-diffusion-2-1", subfolder="vae"
+ # ).to("meta")
+
+ # # copy weights from lol_vae to vae
+ # import json
+
+ # with open("lolvae.json", "w") as f:
+ #     json.dump(list(lol_vae.state_dict().keys()), f, indent=4)
+
+ # with open("vae.json", "w") as f:
+ #     json.dump(list(vae.state_dict().keys()), f, indent=4)
+
+ # sample = torch.randn(1, 3, 512, 512).to("meta")
+ # # latent = vae.encoder(sample)
+
+ from diffusers import UNet2DModel
+
+ model = UNet2DModel(
+     sample_size=512,  # the target image resolution
+     in_channels=3,  # the number of input channels, 3 for RGB images
+     out_channels=3,  # the number of output channels
+     layers_per_block=2,  # how many ResNet layers to use per UNet block
+     block_out_channels=(
+         128,
+         128,
+         256,
+         256,
+         512,
+         512,
+     ),  # the number of output channels for each UNet block
+     down_block_types=(
+         "DownBlock2D",  # a regular ResNet downsampling block
+         "DownBlock2D",
+         "DownBlock2D",
+         "DownBlock2D",
+         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
+         "DownBlock2D",
+     ),
+     up_block_types=(
+         "UpBlock2D",  # a regular ResNet upsampling block
+         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
+         "UpBlock2D",
+         "UpBlock2D",
+         "UpBlock2D",
+         "UpBlock2D",
+     ),
+ ).to("meta")
+
  sample = torch.randn(1, 3, 512, 512).to("meta")

- summary(encoder, input_data=(sample,))
+ summary(
+     model,
+     input_data=(
+         sample,
+         0,
+     ),
+ )
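The summary call passes a dummy timestep of 0 alongside the sample because UNet2DModel.forward expects (sample, timestep). Outside torchinfo, an actual forward pass would look roughly like the sketch below; this is assumed usage on a real device rather than "meta", not part of the commit.

import torch
from diffusers import UNet2DModel

# Same architecture as configured above, instantiated normally so it can run.
unet = UNet2DModel(
    sample_size=512,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D", "DownBlock2D", "DownBlock2D",
        "DownBlock2D", "AttnDownBlock2D", "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D", "AttnUpBlock2D", "UpBlock2D",
        "UpBlock2D", "UpBlock2D", "UpBlock2D",
    ),
)

x = torch.randn(1, 3, 512, 512)
t = torch.tensor([0])
# UNet2DModel returns an output object whose .sample field holds the prediction.
out = unet(x, t).sample
print(out.shape)  # torch.Size([1, 3, 512, 512])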